/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.reef.wake.remote.transport.netty;

import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.ChannelGroupFuture;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GlobalEventExecutor;
import org.apache.reef.tang.annotations.Parameter;
import org.apache.reef.wake.EStage;
import org.apache.reef.wake.EventHandler;
import org.apache.reef.wake.impl.DefaultThreadFactory;
import org.apache.reef.wake.remote.Encoder;
import org.apache.reef.wake.remote.RemoteConfiguration;
import org.apache.reef.wake.remote.address.LocalAddressProvider;
import org.apache.reef.wake.remote.exception.RemoteRuntimeException;
import org.apache.reef.wake.remote.impl.TransportEvent;
import org.apache.reef.wake.remote.ports.TcpPortProvider;
import org.apache.reef.wake.remote.transport.Link;
import org.apache.reef.wake.remote.transport.LinkListener;
import org.apache.reef.wake.remote.transport.Transport;
import org.apache.reef.wake.remote.transport.exception.TransportRuntimeException;

import javax.inject.Inject;
import java.io.IOException;
import java.net.BindException;
import java.net.ConnectException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Messaging transport implementation with Netty.
 *
 * <p>One instance owns a server acceptor channel bound at construction time, a client bootstrap
 * used by {@link #open} to create outgoing links, and the three event loop groups backing them.
 * {@link #close()} releases all of these.
 */
public final class NettyMessagingTransport implements Transport {

  /**
   * Indicates a hostname that isn't set or known.
   */
  public static final String UNKNOWN_HOST_NAME = "##UNKNOWN##";

  private static final String CLASS_NAME = NettyMessagingTransport.class.getSimpleName();

  private static final Logger LOG = Logger.getLogger(CLASS_NAME);

  private static final int SERVER_BOSS_NUM_THREADS = 3;
  private static final int SERVER_WORKER_NUM_THREADS = 20;
  private static final int CLIENT_WORKER_NUM_THREADS = 10;

  /*
   * Values of the per-address "connect in progress" counter as this class drives it:
   * 0 -> 1 when a thread claims the right to connect, 1 -> 2 once the link has been stored,
   * and 1 -> 0 when the attempt failed and another attempt may be made.
   */
  private static final int CONNECT_IDLE = 0;
  private static final int CONNECT_IN_PROGRESS = 1;
  private static final int CONNECT_DONE = 2;

  private final ConcurrentMap<SocketAddress, LinkReference> addrToLinkRefMap = new ConcurrentHashMap<>();

  private final EventLoopGroup clientWorkerGroup;
  private final EventLoopGroup serverBossGroup;
  private final EventLoopGroup serverWorkerGroup;

  private final Bootstrap clientBootstrap;
  private final ServerBootstrap serverBootstrap;
  private final Channel acceptor;

  private final ChannelGroup clientChannelGroup = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);
  private final ChannelGroup serverChannelGroup = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);

  private final int serverPort;
  private final SocketAddress localAddress;

  private final NettyClientEventListener clientEventListener;
  private final NettyServerEventListener serverEventListener;

  private final int numberOfTries;
  private final int retryTimeout;

  /**
   * Constructs a messaging transport.
   *
   * @param hostAddress          the server host address; when it equals {@link #UNKNOWN_HOST_NAME},
   *                             the address returned by {@code localAddressProvider} is used instead
   * @param port                 the server listening port; when it is 0, ports supplied by
   *                             {@code tcpPortProvider} are tried until one can be bound
   * @param clientStage          the client-side stage that handles transport events
   * @param serverStage          the server-side stage that handles transport events
   * @param numberOfTries        the number of tries of connection
   * @param retryTimeout         the timeout of reconnection
   * @param tcpPortProvider      gives an iterator that produces random tcp ports in a range
   * @param localAddressProvider supplies the local address when {@code hostAddress} is unknown
   * @throws RemoteRuntimeException    if {@code port} is negative
   * @throws TransportRuntimeException if no port could be bound; the event loop groups are
   *                                   asked to shut down before this is thrown
   */
  @Inject
  private NettyMessagingTransport(
      @Parameter(RemoteConfiguration.HostAddress.class) final String hostAddress,
      @Parameter(RemoteConfiguration.Port.class) final int port,
      @Parameter(RemoteConfiguration.RemoteClientStage.class) final EStage<TransportEvent> clientStage,
      @Parameter(RemoteConfiguration.RemoteServerStage.class) final EStage<TransportEvent> serverStage,
      @Parameter(RemoteConfiguration.NumberOfTries.class) final int numberOfTries,
      @Parameter(RemoteConfiguration.RetryTimeout.class) final int retryTimeout,
      final TcpPortProvider tcpPortProvider,
      final LocalAddressProvider localAddressProvider) {

    int p = port;
    if (p < 0) {
      throw new RemoteRuntimeException("Invalid server port: " + p);
    }

    final String host = UNKNOWN_HOST_NAME.equals(hostAddress) ?
        localAddressProvider.getLocalAddress() : hostAddress;

    this.numberOfTries = numberOfTries;
    this.retryTimeout = retryTimeout;
    this.clientEventListener = new NettyClientEventListener(this.addrToLinkRefMap, clientStage);
    this.serverEventListener = new NettyServerEventListener(this.addrToLinkRefMap, serverStage);

    this.serverBossGroup = new NioEventLoopGroup(SERVER_BOSS_NUM_THREADS,
        new DefaultThreadFactory(CLASS_NAME + ":ServerBoss"));
    this.serverWorkerGroup = new NioEventLoopGroup(SERVER_WORKER_NUM_THREADS,
        new DefaultThreadFactory(CLASS_NAME + ":ServerWorker"));
    this.clientWorkerGroup = new NioEventLoopGroup(CLIENT_WORKER_NUM_THREADS,
        new DefaultThreadFactory(CLASS_NAME + ":ClientWorker"));

    this.clientBootstrap = new Bootstrap();
    this.clientBootstrap.group(this.clientWorkerGroup)
        .channel(NioSocketChannel.class)
        .handler(new NettyChannelInitializer(new NettyDefaultChannelHandlerFactory("client",
            this.clientChannelGroup, this.clientEventListener)))
        .option(ChannelOption.SO_REUSEADDR, true)
        .option(ChannelOption.SO_KEEPALIVE, true);

    this.serverBootstrap = new ServerBootstrap();
    this.serverBootstrap.group(this.serverBossGroup, this.serverWorkerGroup)
        .channel(NioServerSocketChannel.class)
        .childHandler(new NettyChannelInitializer(new NettyDefaultChannelHandlerFactory("server",
            this.serverChannelGroup, this.serverEventListener)))
        .option(ChannelOption.SO_BACKLOG, 128)
        .option(ChannelOption.SO_REUSEADDR, true)
        .childOption(ChannelOption.SO_KEEPALIVE, true);

    LOG.log(Level.FINE, "Binding to {0}", p);

    Channel acceptorFound = null;
    try {
      if (p > 0) {
        // An explicit port was requested: a single bind attempt, no fallback.
        acceptorFound = this.serverBootstrap.bind(new InetSocketAddress(host, p)).sync().channel();
      } else {
        // Port 0: walk the provider's candidate ports until one binds.
        final Iterator<Integer> ports = tcpPortProvider.iterator();
        while (acceptorFound == null) {
          if (!ports.hasNext()) {
            throw new IllegalStateException("tcpPortProvider cannot find a free port.");
          }
          p = ports.next();
          LOG.log(Level.FINEST, "Try port {0}", p);
          try {
            acceptorFound = this.serverBootstrap.bind(new InetSocketAddress(host, p)).sync().channel();
          } catch (final Exception ex) {
            // BindException is checked but not declared by sync(), so it cannot be named in a
            // catch clause here and has to be recognized with instanceof.
            if (!(ex instanceof BindException)) {
              throw ex;
            }
            LOG.log(Level.FINEST, "The port {0} is already bound. Try again", p);
          }
        }
      }
    } catch (final IllegalStateException ex) {
      final RuntimeException transportException =
          new TransportRuntimeException("tcpPortProvider failed to return free ports.", ex);
      LOG.log(Level.SEVERE, "Cannot find a free port with " + tcpPortProvider, transportException);

      this.shutdownEventLoopGroups();
      throw transportException;

    } catch (final Exception ex) {
      if (ex instanceof InterruptedException) {
        // sync() was interrupted: keep the interrupt visible to the caller.
        Thread.currentThread().interrupt();
      }
      final RuntimeException transportException =
          new TransportRuntimeException("Cannot bind to port " + p, ex);
      LOG.log(Level.SEVERE, "Cannot bind to port " + p, ex);

      this.shutdownEventLoopGroups();
      throw transportException;
    }

    this.acceptor = acceptorFound;
    this.serverPort = p;
    this.localAddress = new InetSocketAddress(host, this.serverPort);

    LOG.log(Level.FINE, "Starting netty transport socket address: {0}", this.localAddress);
  }

  /**
   * Closes all channels and releases all resources.
   *
   * <p>All close and shutdown requests are issued first and only then awaited. If the calling
   * thread is interrupted while waiting for the acceptor to close, the error is logged and the
   * thread's interrupt status is restored.
   */
  @Override
  public void close() {

    LOG.log(Level.FINE, "Closing netty transport socket address: {0}", this.localAddress);

    final ChannelGroupFuture clientChannelGroupFuture = this.clientChannelGroup.close();
    final ChannelGroupFuture serverChannelGroupFuture = this.serverChannelGroup.close();
    final ChannelFuture acceptorFuture = this.acceptor.close();

    final List<Future<?>> eventLoopGroupFutures = this.shutdownEventLoopGroups();

    clientChannelGroupFuture.awaitUninterruptibly();
    serverChannelGroupFuture.awaitUninterruptibly();

    try {
      acceptorFuture.sync();
    } catch (final Exception ex) {
      if (ex instanceof InterruptedException) {
        Thread.currentThread().interrupt();
      }
      LOG.log(Level.SEVERE, "Error closing the acceptor channel for " + this.localAddress, ex);
    }

    for (final Future<?> eventLoopGroupFuture : eventLoopGroupFutures) {
      eventLoopGroupFuture.awaitUninterruptibly();
    }

    LOG.log(Level.FINE, "Closing netty transport socket address: {0} done", this.localAddress);
  }

  /**
   * Returns a link for the remote address if cached; otherwise opens, caches and returns.
   * When it opens a link for the remote address, only one attempt for the address is made at a given time
   *
   * <p>A refused connection is retried after sleeping for the configured retry timeout. If the
   * calling thread is interrupted while waiting or sleeping, retrying continues and the thread's
   * interrupt status is restored, so later sleeps in the same call return immediately.
   *
   * @param remoteAddr the remote socket address
   * @param encoder    the encoder
   * @param listener   the link listener
   * @return a link associated with the address; {@code null} only when the configured number of
   *     tries is negative, in which case no attempt is made at all
   * @throws ConnectException if the connection was still refused after the configured number of
   *     tries; the last refusal, if any, is attached as the cause
   * @throws IOException declared by the {@link Transport} contract
   */
  @Override
  public <T> Link<T> open(final SocketAddress remoteAddr, final Encoder<? super T> encoder,
                          final LinkListener<? super T> listener) throws IOException {

    Exception lastFailure = null;

    for (int i = 0; i <= this.numberOfTries; ++i) {
      LinkReference linkRef = this.addrToLinkRefMap.get(remoteAddr);

      if (linkRef != null) {
        final Link<T> cached = linkOf(linkRef);
        if (cached != null) {
          LOG.log(Level.FINE, "Link {0} for {1} found", new Object[]{cached, remoteAddr});
          return cached;
        }
      }

      if (i == this.numberOfTries) {
        // Connection failure: every attempt has been used up.
        final ConnectException failure = new ConnectException("Connection to " + remoteAddr + " refused");
        if (lastFailure != null) {
          failure.initCause(lastFailure);
        }
        throw failure;
      }

      LOG.log(Level.FINE, "No cached link for {0} thread {1}",
          new Object[]{remoteAddr, Thread.currentThread()});

      // Publish a fresh reference unless another thread already did; either way use the
      // counter of whichever reference ended up in the map.
      final LinkReference newLinkRef = new LinkReference();
      final LinkReference prior = this.addrToLinkRefMap.putIfAbsent(remoteAddr, newLinkRef);
      final AtomicInteger flag = prior != null ?
          prior.getConnectInProgress() : newLinkRef.getConnectInProgress();

      claimOrAwaitConnect(flag);

      // Another thread may have finished connecting while this one was waiting.
      linkRef = this.addrToLinkRefMap.get(remoteAddr);
      final Link<T> existing = linkOf(linkRef);
      if (existing != null) {
        return existing;
      }

      try {
        final ChannelFuture connectFuture = this.clientBootstrap.connect(remoteAddr);
        connectFuture.syncUninterruptibly();

        final Link<T> link = new NettyLink<>(connectFuture.channel(), encoder, listener);
        linkRef.setLink(link);

        transitionAndNotify(flag, CONNECT_IN_PROGRESS, CONNECT_DONE);
        return link;
      } catch (final Exception e) {
        // Whatever went wrong, give up the claim so that threads waiting on this address
        // wake up instead of blocking forever.
        transitionAndNotify(flag, CONNECT_IN_PROGRESS, CONNECT_IDLE);

        // ConnectException is checked but not declared by anything in the try block, so it
        // cannot be named in a catch clause and is recognized with instanceof instead.
        if (!(e instanceof ConnectException)) {
          throw e;
        }

        lastFailure = e;
        LOG.log(Level.WARNING, "Connection refused. Retry {0} of {1}",
            new Object[]{i + 1, this.numberOfTries});

        try {
          Thread.sleep(this.retryTimeout);
        } catch (final InterruptedException interrupt) {
          LOG.log(Level.WARNING, "Thread {0} interrupted while sleeping", Thread.currentThread());
          Thread.currentThread().interrupt();
        }
      }
    }

    // Only reachable when numberOfTries is negative and the loop body never ran.
    return null;
  }

  /**
   * Returns a link for the remote address if already cached; otherwise, returns null.
   *
   * @param remoteAddr the remote address
   * @return a link if already cached; otherwise, null
   */
  public <T> Link<T> get(final SocketAddress remoteAddr) {
    final LinkReference linkRef = this.addrToLinkRefMap.get(remoteAddr);
    if (linkRef == null) {
      return null;
    }
    return linkOf(linkRef);
  }

  /**
   * Gets a server local socket address of this transport.
   *
   * @return a server local socket address
   */
  @Override
  public SocketAddress getLocalAddress() {
    return this.localAddress;
  }

  /**
   * Gets a server listening port of this transport.
   *
   * @return a listening port number
   */
  @Override
  public int getListeningPort() {
    return this.serverPort;
  }

  /**
   * Registers the exception event handler.
   *
   * @param handler the exception event handler
   */
  @Override
  public void registerErrorHandler(final EventHandler<Exception> handler) {
    this.clientEventListener.registerErrorHandler(handler);
    this.serverEventListener.registerErrorHandler(handler);
  }

  /**
   * Requests a graceful shutdown of the client worker, server boss and server worker groups.
   *
   * @return the termination futures, in that order; none of them has been awaited
   */
  private List<Future<?>> shutdownEventLoopGroups() {
    final List<Future<?>> terminationFutures = new ArrayList<>(3);
    terminationFutures.add(this.clientWorkerGroup.shutdownGracefully());
    terminationFutures.add(this.serverBossGroup.shutdownGracefully());
    terminationFutures.add(this.serverWorkerGroup.shutdownGracefully());
    return terminationFutures;
  }

  /**
   * Tries to move the counter from idle to in-progress; if that fails, blocks until the
   * counter is no longer in-progress.
   *
   * <p>An interrupt received while waiting does not end the wait. It is logged and the
   * interrupt status is restored once waiting is over; restoring it inside the loop would make
   * every following {@code wait()} throw immediately.
   *
   * @param flag the per-address connect counter, also used as the monitor
   */
  private static void claimOrAwaitConnect(final AtomicInteger flag) {
    boolean interrupted = false;
    synchronized (flag) {
      if (!flag.compareAndSet(CONNECT_IDLE, CONNECT_IN_PROGRESS)) {
        while (flag.get() == CONNECT_IN_PROGRESS) {
          try {
            flag.wait();
          } catch (final InterruptedException ex) {
            LOG.log(Level.WARNING, "Wait interrupted", ex);
            interrupted = true;
          }
        }
      }
    }
    if (interrupted) {
      Thread.currentThread().interrupt();
    }
  }

  /**
   * Under the counter's monitor, changes the counter from {@code expected} to {@code updated}
   * if it currently holds {@code expected}, then wakes every thread waiting on it.
   *
   * @param flag     the per-address connect counter, also used as the monitor
   * @param expected the value the counter must hold for the change to happen
   * @param updated  the value to store
   */
  private static void transitionAndNotify(final AtomicInteger flag, final int expected, final int updated) {
    synchronized (flag) {
      flag.compareAndSet(expected, updated);
      flag.notifyAll();
    }
  }

  /**
   * Reads the link held by a reference, viewed as a link of the caller's message type.
   * The cast is unchecked: the map does not record the message type of each link, so the
   * caller is trusted to ask for the type the link was opened with.
   *
   * @param linkRef the reference to read; must not be null
   * @return the referenced link, or null if none has been set
   */
  @SuppressWarnings("unchecked")
  private static <T> Link<T> linkOf(final LinkReference linkRef) {
    return (Link<T>) linkRef.getLink();
  }
}