This project has retired. For details, please refer to its Attic page.
Source code
001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package org.apache.reef.wake.remote.transport.netty;
020
021import io.netty.bootstrap.Bootstrap;
022import io.netty.bootstrap.ServerBootstrap;
023import io.netty.channel.Channel;
024import io.netty.channel.ChannelFuture;
025import io.netty.channel.ChannelOption;
026import io.netty.channel.EventLoopGroup;
027import io.netty.channel.group.ChannelGroup;
028import io.netty.channel.group.ChannelGroupFuture;
029import io.netty.channel.group.DefaultChannelGroup;
030import io.netty.channel.nio.NioEventLoopGroup;
031import io.netty.channel.socket.nio.NioServerSocketChannel;
032import io.netty.channel.socket.nio.NioSocketChannel;
033import io.netty.util.concurrent.Future;
034import io.netty.util.concurrent.GlobalEventExecutor;
035import org.apache.reef.tang.annotations.Parameter;
036import org.apache.reef.wake.EStage;
037import org.apache.reef.wake.EventHandler;
038import org.apache.reef.wake.impl.DefaultThreadFactory;
039import org.apache.reef.wake.remote.Encoder;
040import org.apache.reef.wake.remote.RemoteConfiguration;
041import org.apache.reef.wake.remote.address.LocalAddressProvider;
042import org.apache.reef.wake.remote.exception.RemoteRuntimeException;
043import org.apache.reef.wake.remote.impl.TransportEvent;
044import org.apache.reef.wake.remote.ports.TcpPortProvider;
045import org.apache.reef.wake.remote.transport.Link;
046import org.apache.reef.wake.remote.transport.LinkListener;
047import org.apache.reef.wake.remote.transport.Transport;
048import org.apache.reef.wake.remote.transport.exception.TransportRuntimeException;
049
050import javax.inject.Inject;
051import java.io.IOException;
052import java.net.BindException;
053import java.net.ConnectException;
054import java.net.InetSocketAddress;
055import java.net.SocketAddress;
056import java.util.ArrayList;
057import java.util.Iterator;
058import java.util.concurrent.ConcurrentHashMap;
059import java.util.concurrent.ConcurrentMap;
060import java.util.concurrent.atomic.AtomicInteger;
061import java.util.logging.Level;
062import java.util.logging.Logger;
063
064/**
065 * Messaging transport implementation with Netty.
066 */
067public final class NettyMessagingTransport implements Transport {
068
069  /**
070   * Indicates a hostname that isn't set or known.
071   */
072  public static final String UNKNOWN_HOST_NAME = "##UNKNOWN##";
073
074  private static final String CLASS_NAME = NettyMessagingTransport.class.getSimpleName();
075
076  private static final Logger LOG = Logger.getLogger(CLASS_NAME);
077
078  private static final int SERVER_BOSS_NUM_THREADS = 3;
079  private static final int SERVER_WORKER_NUM_THREADS = 20;
080  private static final int CLIENT_WORKER_NUM_THREADS = 10;
081
082  private final ConcurrentMap<SocketAddress, LinkReference> addrToLinkRefMap = new ConcurrentHashMap<>();
083
084  private final EventLoopGroup clientWorkerGroup;
085  private final EventLoopGroup serverBossGroup;
086  private final EventLoopGroup serverWorkerGroup;
087
088  private final Bootstrap clientBootstrap;
089  private final ServerBootstrap serverBootstrap;
090  private final Channel acceptor;
091
092  private final ChannelGroup clientChannelGroup = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);
093  private final ChannelGroup serverChannelGroup = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);
094
095  private final int serverPort;
096  private final SocketAddress localAddress;
097
098  private final NettyClientEventListener clientEventListener;
099  private final NettyServerEventListener serverEventListener;
100
101  private final int numberOfTries;
102  private final int retryTimeout;
103
104  /**
105   * Constructs a messaging transport.
106   *
107   * @param hostAddress   the server host address
108   * @param port          the server listening port; when it is 0, randomly assign a port number
109   * @param clientStage   the client-side stage that handles transport events
110   * @param serverStage   the server-side stage that handles transport events
111   * @param numberOfTries the number of tries of connection
112   * @param retryTimeout  the timeout of reconnection
113   * @param tcpPortProvider  gives an iterator that produces random tcp ports in a range
114   */
115  @Inject
116  private NettyMessagingTransport(
117      @Parameter(RemoteConfiguration.HostAddress.class) final String hostAddress,
118      @Parameter(RemoteConfiguration.Port.class) final int port,
119      @Parameter(RemoteConfiguration.RemoteClientStage.class) final EStage<TransportEvent> clientStage,
120      @Parameter(RemoteConfiguration.RemoteServerStage.class) final EStage<TransportEvent> serverStage,
121      @Parameter(RemoteConfiguration.NumberOfTries.class) final int numberOfTries,
122      @Parameter(RemoteConfiguration.RetryTimeout.class) final int retryTimeout,
123      final TcpPortProvider tcpPortProvider,
124      final LocalAddressProvider localAddressProvider) {
125
126    int p = port;
127    if (p < 0) {
128      throw new RemoteRuntimeException("Invalid server port: " + p);
129    }
130
131    final String host = UNKNOWN_HOST_NAME.equals(hostAddress) ? localAddressProvider.getLocalAddress() : hostAddress;
132
133    this.numberOfTries = numberOfTries;
134    this.retryTimeout = retryTimeout;
135    this.clientEventListener = new NettyClientEventListener(this.addrToLinkRefMap, clientStage);
136    this.serverEventListener = new NettyServerEventListener(this.addrToLinkRefMap, serverStage);
137
138    this.serverBossGroup = new NioEventLoopGroup(SERVER_BOSS_NUM_THREADS,
139        new DefaultThreadFactory(CLASS_NAME + ":ServerBoss"));
140    this.serverWorkerGroup = new NioEventLoopGroup(SERVER_WORKER_NUM_THREADS,
141        new DefaultThreadFactory(CLASS_NAME + ":ServerWorker"));
142    this.clientWorkerGroup = new NioEventLoopGroup(CLIENT_WORKER_NUM_THREADS,
143        new DefaultThreadFactory(CLASS_NAME + ":ClientWorker"));
144
145    this.clientBootstrap = new Bootstrap();
146    this.clientBootstrap.group(this.clientWorkerGroup)
147        .channel(NioSocketChannel.class)
148        .handler(new NettyChannelInitializer(new NettyDefaultChannelHandlerFactory("client",
149            this.clientChannelGroup, this.clientEventListener)))
150        .option(ChannelOption.SO_REUSEADDR, true)
151        .option(ChannelOption.SO_KEEPALIVE, true);
152
153    this.serverBootstrap = new ServerBootstrap();
154    this.serverBootstrap.group(this.serverBossGroup, this.serverWorkerGroup)
155        .channel(NioServerSocketChannel.class)
156        .childHandler(new NettyChannelInitializer(new NettyDefaultChannelHandlerFactory("server",
157            this.serverChannelGroup, this.serverEventListener)))
158        .option(ChannelOption.SO_BACKLOG, 128)
159        .option(ChannelOption.SO_REUSEADDR, true)
160        .childOption(ChannelOption.SO_KEEPALIVE, true);
161
162    LOG.log(Level.FINE, "Binding to {0}", p);
163
164    Channel acceptorFound = null;
165    try {
166      if (p > 0) {
167        acceptorFound = this.serverBootstrap.bind(new InetSocketAddress(host, p)).sync().channel();
168      } else {
169        final Iterator<Integer> ports = tcpPortProvider.iterator();
170        while (acceptorFound == null) {
171          if (!ports.hasNext()) {
172            throw new IllegalStateException("tcpPortProvider cannot find a free port.");
173          }
174          p = ports.next();
175          LOG.log(Level.FINEST, "Try port {0}", p);
176          try {
177            acceptorFound = this.serverBootstrap.bind(new InetSocketAddress(host, p)).sync().channel();
178          } catch (final Exception ex) {
179            if (ex instanceof BindException) {
180              LOG.log(Level.FINEST, "The port {0} is already bound. Try again", p);
181            } else {
182              throw ex;
183            }
184          }
185        }
186      }
187    } catch (final IllegalStateException ex) {
188      final RuntimeException transportException =
189                new TransportRuntimeException("tcpPortProvider failed to return free ports.", ex);
190      LOG.log(Level.SEVERE, "Cannot find a free port with " + tcpPortProvider, transportException);
191
192      this.clientWorkerGroup.shutdownGracefully();
193      this.serverBossGroup.shutdownGracefully();
194      this.serverWorkerGroup.shutdownGracefully();
195      throw transportException;
196
197    } catch (final Exception ex) {
198      final RuntimeException transportException =
199          new TransportRuntimeException("Cannot bind to port " + p, ex);
200      LOG.log(Level.SEVERE, "Cannot bind to port " + p, ex);
201
202      this.clientWorkerGroup.shutdownGracefully();
203      this.serverBossGroup.shutdownGracefully();
204      this.serverWorkerGroup.shutdownGracefully();
205      throw transportException;
206    }
207
208    this.acceptor = acceptorFound;
209    this.serverPort = p;
210    this.localAddress = new InetSocketAddress(host, this.serverPort);
211
212    LOG.log(Level.FINE, "Starting netty transport socket address: {0}", this.localAddress);
213  }
214
215  /**
216   * Closes all channels and releases all resources.
217   */
218  @Override
219  public void close() {
220
221    LOG.log(Level.FINE, "Closing netty transport socket address: {0}", this.localAddress);
222
223    final ChannelGroupFuture clientChannelGroupFuture = this.clientChannelGroup.close();
224    final ChannelGroupFuture serverChannelGroupFuture = this.serverChannelGroup.close();
225    final ChannelFuture acceptorFuture = this.acceptor.close();
226
227    final ArrayList<Future> eventLoopGroupFutures = new ArrayList<>(3);
228    eventLoopGroupFutures.add(this.clientWorkerGroup.shutdownGracefully());
229    eventLoopGroupFutures.add(this.serverBossGroup.shutdownGracefully());
230    eventLoopGroupFutures.add(this.serverWorkerGroup.shutdownGracefully());
231
232    clientChannelGroupFuture.awaitUninterruptibly();
233    serverChannelGroupFuture.awaitUninterruptibly();
234
235    try {
236      acceptorFuture.sync();
237    } catch (final Exception ex) {
238      LOG.log(Level.SEVERE, "Error closing the acceptor channel for " + this.localAddress, ex);
239    }
240
241    for (final Future eventLoopGroupFuture : eventLoopGroupFutures) {
242      eventLoopGroupFuture.awaitUninterruptibly();
243    }
244
245    LOG.log(Level.FINE, "Closing netty transport socket address: {0} done", this.localAddress);
246  }
247
248  /**
249   * Returns a link for the remote address if cached; otherwise opens, caches and returns.
250   * When it opens a link for the remote address, only one attempt for the address is made at a given time
251   *
252   * @param remoteAddr the remote socket address
253   * @param encoder    the encoder
254   * @param listener   the link listener
255   * @return a link associated with the address
256   */
257  @Override
258  public <T> Link<T> open(final SocketAddress remoteAddr, final Encoder<? super T> encoder,
259                          final LinkListener<? super T> listener) throws IOException {
260
261    Link<T> link = null;
262
263    for (int i = 0; i <= this.numberOfTries; ++i) {
264      LinkReference linkRef = this.addrToLinkRefMap.get(remoteAddr);
265
266      if (linkRef != null) {
267        link = (Link<T>) linkRef.getLink();
268        if (LOG.isLoggable(Level.FINE)) {
269          LOG.log(Level.FINE, "Link {0} for {1} found", new Object[]{link, remoteAddr});
270        }
271        if (link != null) {
272          return link;
273        }
274      }
275      
276      if (i == this.numberOfTries) {
277        // Connection failure 
278        throw new ConnectException("Connection to " + remoteAddr + " refused");
279      }
280
281      LOG.log(Level.FINE, "No cached link for {0} thread {1}",
282          new Object[]{remoteAddr, Thread.currentThread()});
283
284      // no linkRef
285      final LinkReference newLinkRef = new LinkReference();
286      final LinkReference prior = this.addrToLinkRefMap.putIfAbsent(remoteAddr, newLinkRef);
287      final AtomicInteger flag = prior != null ?
288          prior.getConnectInProgress() : newLinkRef.getConnectInProgress();
289
290      synchronized (flag) {
291        if (!flag.compareAndSet(0, 1)) {
292          while (flag.get() == 1) {
293            try {
294              flag.wait();
295            } catch (final InterruptedException ex) {
296              LOG.log(Level.WARNING, "Wait interrupted", ex);
297            }
298          }
299        }
300      }
301
302      linkRef = this.addrToLinkRefMap.get(remoteAddr);
303      link = (Link<T>) linkRef.getLink();
304
305      if (link != null) {
306        return link;
307      }
308
309      ChannelFuture connectFuture = null;
310      try {
311        connectFuture = this.clientBootstrap.connect(remoteAddr);
312        connectFuture.syncUninterruptibly();
313
314        link = new NettyLink<>(connectFuture.channel(), encoder, listener);
315        linkRef.setLink(link);
316
317        synchronized (flag) {
318          flag.compareAndSet(1, 2);
319          flag.notifyAll();
320        }
321        break;
322      } catch (final Exception e) {
323        if (e.getClass().getSimpleName().compareTo("ConnectException") == 0) {
324          LOG.log(Level.WARNING, "Connection refused. Retry {0} of {1}",
325              new Object[]{i + 1, this.numberOfTries});
326          synchronized (flag) {
327            flag.compareAndSet(1, 0);
328            flag.notifyAll();
329          }
330
331          if (i < this.numberOfTries) {
332            try {
333              Thread.sleep(retryTimeout);
334            } catch (final InterruptedException interrupt) {
335              LOG.log(Level.WARNING, "Thread {0} interrupted while sleeping", Thread.currentThread());
336            }
337          }
338        } else {
339          throw e;
340        }
341      }
342    }
343    
344    return link;
345  }
346
347  /**
348   * Returns a link for the remote address if already cached; otherwise, returns null.
349   *
350   * @param remoteAddr the remote address
351   * @return a link if already cached; otherwise, null
352   */
353  public <T> Link<T> get(final SocketAddress remoteAddr) {
354    final LinkReference linkRef = this.addrToLinkRefMap.get(remoteAddr);
355    return linkRef != null ? (Link<T>) linkRef.getLink() : null;
356  }
357
358  /**
359   * Gets a server local socket address of this transport.
360   *
361   * @return a server local socket address
362   */
363  @Override
364  public SocketAddress getLocalAddress() {
365    return this.localAddress;
366  }
367
368  /**
369   * Gets a server listening port of this transport.
370   *
371   * @return a listening port number
372   */
373  @Override
374  public int getListeningPort() {
375    return this.serverPort;
376  }
377
378  /**
379   * Registers the exception event handler.
380   *
381   * @param handler the exception event handler
382   */
383  @Override
384  public void registerErrorHandler(final EventHandler<Exception> handler) {
385    this.clientEventListener.registerErrorHandler(handler);
386    this.serverEventListener.registerErrorHandler(handler);
387  }
388}