This project has retired. For details please refer to its Attic page.
Source code
001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package org.apache.reef.io.network.group.impl.task;
020
021import org.apache.reef.driver.task.TaskConfigurationOptions;
022import org.apache.reef.exception.evaluator.NetworkException;
023import org.apache.reef.io.network.group.api.operators.*;
024import org.apache.reef.io.network.group.impl.config.parameters.DriverIdentifierGroupComm;
025import org.apache.reef.io.network.group.impl.driver.TopologySimpleNode;
026import org.apache.reef.io.network.group.impl.driver.TopologySerializer;
027import org.apache.reef.io.network.impl.NetworkService;
028import org.apache.reef.io.network.group.api.GroupChanges;
029import org.apache.reef.io.network.group.api.task.CommGroupNetworkHandler;
030import org.apache.reef.io.network.group.api.task.CommunicationGroupServiceClient;
031import org.apache.reef.io.network.group.api.task.GroupCommNetworkHandler;
032import org.apache.reef.io.network.group.impl.GroupChangesCodec;
033import org.apache.reef.io.network.group.impl.GroupChangesImpl;
034import org.apache.reef.io.network.group.impl.GroupCommunicationMessage;
035import org.apache.reef.io.network.group.impl.config.parameters.CommunicationGroupName;
036import org.apache.reef.io.network.group.impl.config.parameters.OperatorName;
037import org.apache.reef.io.network.group.impl.config.parameters.SerializedOperConfigs;
038import org.apache.reef.io.network.group.impl.operators.Sender;
039import org.apache.reef.io.network.group.impl.utils.Utils;
040import org.apache.reef.io.network.proto.ReefNetworkGroupCommProtos;
041import org.apache.reef.io.network.util.Pair;
042import org.apache.reef.io.serialization.Codec;
043import org.apache.reef.tang.Configuration;
044import org.apache.reef.tang.Injector;
045import org.apache.reef.tang.Tang;
046import org.apache.reef.tang.annotations.Name;
047import org.apache.reef.tang.annotations.Parameter;
048import org.apache.reef.tang.exceptions.InjectionException;
049import org.apache.reef.tang.formats.ConfigurationSerializer;
050import org.apache.reef.wake.EStage;
051import org.apache.reef.wake.Identifier;
052import org.apache.reef.wake.IdentifierFactory;
053import org.apache.reef.wake.impl.ThreadPoolStage;
054
import javax.inject.Inject;
import java.io.IOException;
import java.util.*;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.logging.Level;
import java.util.logging.Logger;
061
062public class CommunicationGroupClientImpl implements CommunicationGroupServiceClient {
063  private static final Logger LOG = Logger.getLogger(CommunicationGroupClientImpl.class.getName());
064
065  private final GroupCommNetworkHandler groupCommNetworkHandler;
066  private final Class<? extends Name<String>> groupName;
067  private final Map<Class<? extends Name<String>>, GroupCommOperator> operators;
068  private final Sender sender;
069
070  private final String taskId;
071  private final boolean isScatterSender;
072  private final IdentifierFactory identifierFactory;
073  private List<Identifier> activeSlaveTasks;
074  private TopologySimpleNode topologySimpleNodeRoot;
075
076  private final String driverId;
077
078  private final CommGroupNetworkHandler commGroupNetworkHandler;
079
080  private final AtomicBoolean init = new AtomicBoolean(false);
081
082  /**
083   * @deprecated in 0.14.
084   * Use the private constructor that receives an {@code injector} as a parameter instead.
085   */
086  @Deprecated
087  @Inject
088  public CommunicationGroupClientImpl(@Parameter(CommunicationGroupName.class) final String groupName,
089                                      @Parameter(TaskConfigurationOptions.Identifier.class) final String taskId,
090                                      @Parameter(DriverIdentifierGroupComm.class) final String driverId,
091                                      final GroupCommNetworkHandler groupCommNetworkHandler,
092                                      @Parameter(SerializedOperConfigs.class) final Set<String> operatorConfigs,
093                                      final ConfigurationSerializer configSerializer,
094                                      final NetworkService<GroupCommunicationMessage> netService) {
095    this.taskId = taskId;
096    this.driverId = driverId;
097    LOG.finest(groupName + " has GroupCommHandler-" + groupCommNetworkHandler.toString());
098    this.identifierFactory = netService.getIdentifierFactory();
099    this.groupName = Utils.getClass(groupName);
100    this.groupCommNetworkHandler = groupCommNetworkHandler;
101    this.sender = new Sender(netService);
102    this.operators = new TreeMap<>(new Comparator<Class<? extends Name<String>>>() {
103
104      @Override
105      public int compare(final Class<? extends Name<String>> o1, final Class<? extends Name<String>> o2) {
106        final String s1 = o1.getSimpleName();
107        final String s2 = o2.getSimpleName();
108        return s1.compareTo(s2);
109      }
110    });
111    try {
112      this.commGroupNetworkHandler = Tang.Factory.getTang().newInjector().getInstance(CommGroupNetworkHandler.class);
113      this.groupCommNetworkHandler.register(this.groupName, commGroupNetworkHandler);
114
115      boolean operatorIsScatterSender = false;
116      for (final String operatorConfigStr : operatorConfigs) {
117
118        final Configuration operatorConfig = configSerializer.fromString(operatorConfigStr);
119        final Injector injector = Tang.Factory.getTang().newInjector(operatorConfig);
120
121        injector.bindVolatileParameter(TaskConfigurationOptions.Identifier.class, taskId);
122        injector.bindVolatileParameter(CommunicationGroupName.class, groupName);
123        injector.bindVolatileInstance(CommGroupNetworkHandler.class, commGroupNetworkHandler);
124        injector.bindVolatileInstance(NetworkService.class, netService);
125        injector.bindVolatileInstance(CommunicationGroupServiceClient.class, this);
126
127        final GroupCommOperator operator = injector.getInstance(GroupCommOperator.class);
128        final String operName = injector.getNamedInstance(OperatorName.class);
129        this.operators.put(Utils.getClass(operName), operator);
130        LOG.finest(operName + " has CommGroupHandler-" + commGroupNetworkHandler.toString());
131
132        if (!operatorIsScatterSender && operator instanceof Scatter.Sender) {
133          LOG.fine(operName + " is a scatter sender. Will keep track of active slave tasks.");
134          operatorIsScatterSender = true;
135        }
136      }
137      this.isScatterSender = operatorIsScatterSender;
138    } catch (final InjectionException | IOException e) {
139      throw new RuntimeException("Unable to deserialize operator config", e);
140    }
141  }
142
143  @Inject
144  private CommunicationGroupClientImpl(@Parameter(CommunicationGroupName.class) final String groupName,
145                                      @Parameter(TaskConfigurationOptions.Identifier.class) final String taskId,
146                                      @Parameter(DriverIdentifierGroupComm.class) final String driverId,
147                                      final GroupCommNetworkHandler groupCommNetworkHandler,
148                                      @Parameter(SerializedOperConfigs.class) final Set<String> operatorConfigs,
149                                      final ConfigurationSerializer configSerializer,
150                                      final NetworkService<GroupCommunicationMessage> netService,
151                                      final CommGroupNetworkHandler commGroupNetworkHandler,
152                                      final Injector injector) {
153    this.taskId = taskId;
154    this.driverId = driverId;
155    LOG.finest(groupName + " has GroupCommHandler-" + groupCommNetworkHandler.toString());
156    this.identifierFactory = netService.getIdentifierFactory();
157    this.groupName = Utils.getClass(groupName);
158    this.groupCommNetworkHandler = groupCommNetworkHandler;
159    this.commGroupNetworkHandler = commGroupNetworkHandler;
160    this.sender = new Sender(netService);
161    this.operators = new TreeMap<>(new Comparator<Class<? extends Name<String>>>() {
162
163      @Override
164      public int compare(final Class<? extends Name<String>> o1, final Class<? extends Name<String>> o2) {
165        final String s1 = o1.getSimpleName();
166        final String s2 = o2.getSimpleName();
167        return s1.compareTo(s2);
168      }
169    });
170    try {
171      this.groupCommNetworkHandler.register(this.groupName, commGroupNetworkHandler);
172
173      boolean operatorIsScatterSender = false;
174      for (final String operatorConfigStr : operatorConfigs) {
175
176        final Configuration operatorConfig = configSerializer.fromString(operatorConfigStr);
177        final Injector forkedInjector = injector.forkInjector(operatorConfig);
178
179        forkedInjector.bindVolatileInstance(CommunicationGroupServiceClient.class, this);
180
181        final GroupCommOperator operator = forkedInjector.getInstance(GroupCommOperator.class);
182        final String operName = forkedInjector.getNamedInstance(OperatorName.class);
183        this.operators.put(Utils.getClass(operName), operator);
184        LOG.finest(operName + " has CommGroupHandler-" + commGroupNetworkHandler.toString());
185
186        if (!operatorIsScatterSender && operator instanceof Scatter.Sender) {
187          LOG.fine(operName + " is a scatter sender. Will keep track of active slave tasks.");
188          operatorIsScatterSender = true;
189        }
190      }
191      this.isScatterSender = operatorIsScatterSender;
192    } catch (final InjectionException | IOException e) {
193      throw new RuntimeException("Unable to deserialize operator config", e);
194    }
195  }
196
197  @Override
198  public Broadcast.Sender getBroadcastSender(final Class<? extends Name<String>> operatorName) {
199    LOG.entering("CommunicationGroupClientImpl", "getBroadcastSender", new Object[]{getQualifiedName(),
200        Utils.simpleName(operatorName)});
201    final GroupCommOperator op = operators.get(operatorName);
202    if (!(op instanceof Broadcast.Sender)) {
203      throw new RuntimeException("Configured operator is not a broadcast sender");
204    }
205    commGroupNetworkHandler.addTopologyElement(operatorName);
206    LOG.exiting("CommunicationGroupClientImpl", "getBroadcastSender", getQualifiedName() + op);
207    return (Broadcast.Sender) op;
208  }
209
210  @Override
211  public Reduce.Receiver getReduceReceiver(final Class<? extends Name<String>> operatorName) {
212    LOG.entering("CommunicationGroupClientImpl", "getReduceReceiver", new Object[]{getQualifiedName(),
213        Utils.simpleName(operatorName)});
214    final GroupCommOperator op = operators.get(operatorName);
215    if (!(op instanceof Reduce.Receiver)) {
216      throw new RuntimeException("Configured operator is not a reduce receiver");
217    }
218    commGroupNetworkHandler.addTopologyElement(operatorName);
219    LOG.exiting("CommunicationGroupClientImpl", "getReduceReceiver", getQualifiedName() + op);
220    return (Reduce.Receiver) op;
221  }
222
223  @Override
224  public Scatter.Sender getScatterSender(final Class<? extends Name<String>> operatorName) {
225    LOG.entering("CommunicationGroupClientImpl", "getScatterSender", new Object[]{getQualifiedName(),
226        Utils.simpleName(operatorName)});
227    final GroupCommOperator op = operators.get(operatorName);
228    if (!(op instanceof Scatter.Sender)) {
229      throw new RuntimeException("Configured operator is not a scatter sender");
230    }
231    commGroupNetworkHandler.addTopologyElement(operatorName);
232    LOG.exiting("CommunicationGroupClientImpl", "getScatterSender", getQualifiedName() + op);
233    return (Scatter.Sender) op;
234  }
235
236  @Override
237  public Gather.Receiver getGatherReceiver(final Class<? extends Name<String>> operatorName) {
238    LOG.entering("CommunicationGroupClientImpl", "getGatherReceiver", new Object[]{getQualifiedName(),
239        Utils.simpleName(operatorName)});
240    final GroupCommOperator op = operators.get(operatorName);
241    if (!(op instanceof Gather.Receiver)) {
242      throw new RuntimeException("Configured operator is not a gather receiver");
243    }
244    commGroupNetworkHandler.addTopologyElement(operatorName);
245    LOG.exiting("CommunicationGroupClientImpl", "getGatherReceiver", getQualifiedName() + op);
246    return (Gather.Receiver) op;
247  }
248
249
250  @Override
251  public Broadcast.Receiver getBroadcastReceiver(final Class<? extends Name<String>> operatorName) {
252    LOG.entering("CommunicationGroupClientImpl", "getBroadcastReceiver", new Object[]{getQualifiedName(),
253        Utils.simpleName(operatorName)});
254    final GroupCommOperator op = operators.get(operatorName);
255    if (!(op instanceof Broadcast.Receiver)) {
256      throw new RuntimeException("Configured operator is not a broadcast receiver");
257    }
258    commGroupNetworkHandler.addTopologyElement(operatorName);
259    LOG.exiting("CommunicationGroupClientImpl", "getBroadcastReceiver", getQualifiedName() + op);
260    return (Broadcast.Receiver) op;
261  }
262
263  @Override
264  public Reduce.Sender getReduceSender(final Class<? extends Name<String>> operatorName) {
265    LOG.entering("CommunicationGroupClientImpl", "getReduceSender", new Object[]{getQualifiedName(),
266        Utils.simpleName(operatorName)});
267    final GroupCommOperator op = operators.get(operatorName);
268    if (!(op instanceof Reduce.Sender)) {
269      throw new RuntimeException("Configured operator is not a reduce sender");
270    }
271    commGroupNetworkHandler.addTopologyElement(operatorName);
272    LOG.exiting("CommunicationGroupClientImpl", "getReduceSender", getQualifiedName() + op);
273    return (Reduce.Sender) op;
274  }
275
276  @Override
277  public Scatter.Receiver getScatterReceiver(final Class<? extends Name<String>> operatorName) {
278    LOG.entering("CommunicationGroupClientImpl", "getScatterReceiver", new Object[]{getQualifiedName(),
279        Utils.simpleName(operatorName)});
280    final GroupCommOperator op = operators.get(operatorName);
281    if (!(op instanceof Scatter.Receiver)) {
282      throw new RuntimeException("Configured operator is not a scatter receiver");
283    }
284    commGroupNetworkHandler.addTopologyElement(operatorName);
285    LOG.exiting("CommunicationGroupClientImpl", "getScatterReceiver", getQualifiedName() + op);
286    return (Scatter.Receiver) op;
287  }
288
289  @Override
290  public Gather.Sender getGatherSender(final Class<? extends Name<String>> operatorName) {
291    LOG.entering("CommunicationGroupClientImpl", "getGatherSender", new Object[]{getQualifiedName(),
292        Utils.simpleName(operatorName)});
293    final GroupCommOperator op = operators.get(operatorName);
294    if (!(op instanceof Gather.Sender)) {
295      throw new RuntimeException("Configured operator is not a gather sender");
296    }
297    commGroupNetworkHandler.addTopologyElement(operatorName);
298    LOG.exiting("CommunicationGroupClientImpl", "getGatherSender", getQualifiedName() + op);
299    return (Gather.Sender) op;
300  }
301
302  @Override
303  public void initialize() {
304    LOG.entering("CommunicationGroupClientImpl", "initialize", getQualifiedName());
305    if (init.compareAndSet(false, true)) {
306      LOG.finest("CommGroup-" + groupName + " is initializing");
307      final CountDownLatch initLatch = new CountDownLatch(operators.size());
308
309      final InitHandler initHandler = new InitHandler(initLatch);
310      final EStage<GroupCommOperator> initStage = new ThreadPoolStage<>(initHandler, operators.size());
311      for (final GroupCommOperator op : operators.values()) {
312        initStage.onNext(op);
313      }
314
315      try {
316        initLatch.await();
317      } catch (final InterruptedException e) {
318        throw new RuntimeException("InterruptedException while waiting for initialization", e);
319      }
320
321      if (isScatterSender) {
322        updateTopology();
323      }
324
325      if (initHandler.getException() != null) {
326        throw new RuntimeException(getQualifiedName() + "Parent dead. Current behavior is for the child to die too.");
327      }
328    }
329    LOG.exiting("CommunicationGroupClientImpl", "initialize", getQualifiedName());
330  }
331
332  @Override
333  public GroupChanges getTopologyChanges() {
334    LOG.entering("CommunicationGroupClientImpl", "getTopologyChanges", getQualifiedName());
335    for (final GroupCommOperator op : operators.values()) {
336      final Class<? extends Name<String>> operName = op.getOperName();
337      LOG.finest("Sending TopologyChanges msg to driver");
338      try {
339        sender.send(Utils.bldVersionedGCM(groupName, operName,
340            ReefNetworkGroupCommProtos.GroupCommMessage.Type.TopologyChanges, taskId, op.getVersion(), driverId,
341            0, Utils.EMPTY_BYTE_ARR));
342      } catch (final NetworkException e) {
343        throw new RuntimeException("NetworkException while sending GetTopologyChanges", e);
344      }
345    }
346    final Codec<GroupChanges> changesCodec = new GroupChangesCodec();
347    final Map<Class<? extends Name<String>>, GroupChanges> perOpChanges = new HashMap<>();
348    for (final GroupCommOperator op : operators.values()) {
349      final Class<? extends Name<String>> operName = op.getOperName();
350      final byte[] changes = commGroupNetworkHandler.waitForTopologyChanges(operName);
351      perOpChanges.put(operName, changesCodec.decode(changes));
352    }
353    final GroupChanges retVal = mergeGroupChanges(perOpChanges);
354    LOG.exiting("CommunicationGroupClientImpl", "getTopologyChanges", getQualifiedName() + retVal);
355    return retVal;
356  }
357
358  /**
359   * @param perOpChanges
360   * @return
361   */
362  private GroupChanges mergeGroupChanges(final Map<Class<? extends Name<String>>, GroupChanges> perOpChanges) {
363    LOG.entering("CommunicationGroupClientImpl", "mergeGroupChanges", new Object[]{getQualifiedName(), perOpChanges});
364    boolean doChangesExist = false;
365    for (final GroupChanges change : perOpChanges.values()) {
366      if (change.exist()) {
367        doChangesExist = true;
368        break;
369      }
370    }
371    final GroupChanges changes = new GroupChangesImpl(doChangesExist);
372    LOG.exiting("CommunicationGroupClientImpl", "mergeGroupChanges", getQualifiedName() + changes);
373    return changes;
374  }
375
376  @Override
377  public void updateTopology() {
378    LOG.entering("CommunicationGroupClientImpl", "updateTopology", getQualifiedName());
379    for (final GroupCommOperator op : operators.values()) {
380      final Class<? extends Name<String>> operName = op.getOperName();
381      try {
382        sender.send(Utils.bldVersionedGCM(groupName, operName,
383            ReefNetworkGroupCommProtos.GroupCommMessage.Type.UpdateTopology, taskId, op.getVersion(), driverId,
384            0, Utils.EMPTY_BYTE_ARR));
385      } catch (final NetworkException e) {
386        throw new RuntimeException("NetworkException while sending UpdateTopology", e);
387      }
388    }
389    for (final GroupCommOperator op : operators.values()) {
390      final Class<? extends Name<String>> operName = op.getOperName();
391      GroupCommunicationMessage msg;
392      do {
393        msg = commGroupNetworkHandler.waitForTopologyUpdate(operName);
394      } while (!isMsgVersionOk(msg));
395
396      if (isScatterSender) {
397        updateActiveTasks(msg);
398      }
399    }
400    LOG.exiting("CommunicationGroupClientImpl", "updateTopology", getQualifiedName());
401  }
402
403  private void updateActiveTasks(final GroupCommunicationMessage msg) {
404    LOG.entering("CommunicationGroupClientImpl", "updateActiveTasks", new Object[]{getQualifiedName(), msg});
405
406    final Pair<TopologySimpleNode, List<Identifier>> pair =
407        TopologySerializer.decode(msg.getData()[0], identifierFactory);
408
409    topologySimpleNodeRoot = pair.getFirst();
410
411    activeSlaveTasks = pair.getSecond();
412    // remove myself
413    activeSlaveTasks.remove(identifierFactory.getNewInstance(taskId));
414    // sort the tasks in lexicographical order on task ids
415    Collections.sort(activeSlaveTasks, new Comparator<Identifier>() {
416      @Override
417      public int compare(final Identifier o1, final Identifier o2) {
418        return o1.toString().compareTo(o2.toString());
419      }
420    });
421
422    LOG.exiting("CommunicationGroupClientImpl", "updateActiveTasks", new Object[]{getQualifiedName(), msg});
423  }
424
425  private boolean isMsgVersionOk(final GroupCommunicationMessage msg) {
426    LOG.entering("CommunicationGroupClientImpl", "isMsgVersionOk", new Object[]{getQualifiedName(), msg});
427    if (msg.hasVersion()) {
428      final int msgVersion = msg.getVersion();
429      final GroupCommOperator operator = operators.get(Utils.getClass(msg.getOperatorname()));
430      final int nodeVersion = operator.getVersion();
431      final boolean retVal;
432      if (msgVersion < nodeVersion) {
433        LOG.warning(getQualifiedName() + "Received a ver-" + msgVersion + " msg while expecting ver-" + nodeVersion
434            + ". Discarding msg");
435        retVal = false;
436      } else {
437        retVal = true;
438      }
439      LOG.exiting("CommunicationGroupClientImpl", "isMsgVersionOk",
440          Arrays.toString(new Object[]{retVal, getQualifiedName(), msg}));
441      return retVal;
442    } else {
443      throw new RuntimeException(getQualifiedName() + "can only deal with versioned msgs");
444    }
445  }
446
447  @Override
448  public List<Identifier> getActiveSlaveTasks() {
449    return this.activeSlaveTasks;
450  }
451
452  @Override
453  public TopologySimpleNode getTopologySimpleNodeRoot() {
454    return this.topologySimpleNodeRoot;
455  }
456
457  private String getQualifiedName() {
458    return Utils.simpleName(groupName) + " ";
459  }
460
461  @Override
462  public Class<? extends Name<String>> getName() {
463    return groupName;
464  }
465
466}