This project has retired. For details please refer to its Attic page.
Source code
001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package org.apache.reef.io.network.group.impl.driver;
020
021import org.apache.reef.driver.parameters.DriverIdentifier;
022import org.apache.reef.io.network.group.api.operators.GroupCommOperator;
023import org.apache.reef.io.network.group.api.GroupChanges;
024import org.apache.reef.io.network.group.api.config.OperatorSpec;
025import org.apache.reef.io.network.group.api.driver.TaskNode;
026import org.apache.reef.io.network.group.api.driver.Topology;
027import org.apache.reef.io.network.group.impl.GroupChangesCodec;
028import org.apache.reef.io.network.group.impl.GroupChangesImpl;
029import org.apache.reef.io.network.group.impl.GroupCommunicationMessage;
030import org.apache.reef.io.network.group.impl.config.BroadcastOperatorSpec;
031import org.apache.reef.io.network.group.impl.config.GatherOperatorSpec;
032import org.apache.reef.io.network.group.impl.config.ReduceOperatorSpec;
033import org.apache.reef.io.network.group.impl.config.ScatterOperatorSpec;
034import org.apache.reef.io.network.group.impl.config.parameters.*;
035import org.apache.reef.io.network.group.impl.operators.*;
036import org.apache.reef.io.network.group.impl.utils.Utils;
037import org.apache.reef.io.network.proto.ReefNetworkGroupCommProtos;
038import org.apache.reef.io.serialization.Codec;
039import org.apache.reef.tang.Configuration;
040import org.apache.reef.tang.JavaConfigurationBuilder;
041import org.apache.reef.tang.Tang;
042import org.apache.reef.tang.annotations.Name;
043import org.apache.reef.tang.annotations.Parameter;
044import org.apache.reef.tang.formats.AvroConfigurationSerializer;
045import org.apache.reef.tang.formats.ConfigurationSerializer;
046import org.apache.reef.wake.EStage;
047import org.apache.reef.wake.EventHandler;
048import org.apache.reef.wake.impl.SingleThreadStage;
049
050import javax.inject.Inject;
051import java.util.ArrayList;
052import java.util.List;
053import java.util.Map;
054import java.util.concurrent.ConcurrentMap;
055import java.util.concurrent.ConcurrentSkipListMap;
056import java.util.logging.Logger;
057
058/**
059 * Implements a tree topology with the specified Fan Out.
060 */
061public final class TreeTopology implements Topology {
062
063  private static final Logger LOG = Logger.getLogger(TreeTopology.class.getName());
064
065  private final EStage<GroupCommunicationMessage> senderStage;
066  private final Class<? extends Name<String>> groupName;
067  private final Class<? extends Name<String>> operName;
068  private final String driverId;
069  private String rootId;
070  private OperatorSpec operatorSpec;
071
072  private TaskNode root;
073  private TaskNode logicalRoot;
074  private TaskNode prev;
075  private final int fanOut;
076
077  private final ConcurrentMap<String, TaskNode> nodes = new ConcurrentSkipListMap<>();
078  private final ConfigurationSerializer confSer = new AvroConfigurationSerializer();
079
080  @Inject
081  private TreeTopology(@Parameter(GroupCommSenderStage.class) final EStage<GroupCommunicationMessage> senderStage,
082                       @Parameter(CommGroupNameClass.class) final Class<? extends Name<String>> groupName,
083                       @Parameter(OperatorNameClass.class) final Class<? extends Name<String>> operatorName,
084                       @Parameter(DriverIdentifier.class) final String driverId,
085                       @Parameter(TreeTopologyFanOut.class) final int fanOut) {
086    this.senderStage = senderStage;
087    this.groupName = groupName;
088    this.operName = operatorName;
089    this.driverId = driverId;
090    this.fanOut = fanOut;
091    LOG.config(getQualifiedName() + "Tree Topology running with a fan-out of " + fanOut);
092  }
093
094  @Override
095  @SuppressWarnings("checkstyle:hiddenfield")
096  public void setRootTask(final String rootId) {
097    LOG.entering("TreeTopology", "setRootTask", new Object[]{getQualifiedName(), rootId});
098    this.rootId = rootId;
099    LOG.exiting("TreeTopology", "setRootTask", getQualifiedName() + rootId);
100  }
101
102  @Override
103  public String getRootId() {
104    LOG.entering("TreeTopology", "getRootId", getQualifiedName());
105    LOG.exiting("TreeTopology", "getRootId", getQualifiedName() + rootId);
106    return rootId;
107  }
108
109  @Override
110  public boolean isRootPresent() {
111    LOG.entering("TreeTopology", "isRootPresent", getQualifiedName());
112    final boolean retVal = root != null;
113    LOG.exiting("TreeTopology", "isRootPresent", String.format("%s%s", getQualifiedName(), retVal));
114    return retVal;
115  }
116
117  @Override
118  public void setOperatorSpecification(final OperatorSpec spec) {
119    LOG.entering("TreeTopology", "setOperSpec", new Object[]{getQualifiedName(), spec});
120    this.operatorSpec = spec;
121    LOG.exiting("TreeTopology", "setOperSpec", getQualifiedName() + spec);
122  }
123
124  @Override
125  public Configuration getTaskConfiguration(final String taskId) {
126    LOG.entering("TreeTopology", "getTaskConfig", new Object[]{getQualifiedName(), taskId});
127    final TaskNode taskNode = nodes.get(taskId);
128    if (taskNode == null) {
129      throw new RuntimeException(getQualifiedName() + taskId + " does not exist");
130    }
131
132    final int version = getNodeVersion(taskId);
133    final JavaConfigurationBuilder jcb = Tang.Factory.getTang().newConfigurationBuilder();
134    jcb.bindNamedParameter(DataCodec.class, operatorSpec.getDataCodecClass());
135    jcb.bindNamedParameter(TaskVersion.class, Integer.toString(version));
136    if (operatorSpec instanceof BroadcastOperatorSpec) {
137      final BroadcastOperatorSpec broadcastOperatorSpec = (BroadcastOperatorSpec) operatorSpec;
138      if (taskId.equals(broadcastOperatorSpec.getSenderId())) {
139        jcb.bindImplementation(GroupCommOperator.class, BroadcastSender.class);
140      } else {
141        jcb.bindImplementation(GroupCommOperator.class, BroadcastReceiver.class);
142      }
143    } else if (operatorSpec instanceof ReduceOperatorSpec) {
144      final ReduceOperatorSpec reduceOperatorSpec = (ReduceOperatorSpec) operatorSpec;
145      jcb.bindNamedParameter(ReduceFunctionParam.class, reduceOperatorSpec.getRedFuncClass());
146      if (taskId.equals(reduceOperatorSpec.getReceiverId())) {
147        jcb.bindImplementation(GroupCommOperator.class, ReduceReceiver.class);
148      } else {
149        jcb.bindImplementation(GroupCommOperator.class, ReduceSender.class);
150      }
151    } else if (operatorSpec instanceof ScatterOperatorSpec) {
152      final ScatterOperatorSpec scatterOperatorSpec = (ScatterOperatorSpec) operatorSpec;
153      if (taskId.equals(scatterOperatorSpec.getSenderId())) {
154        jcb.bindImplementation(GroupCommOperator.class, ScatterSender.class);
155      } else {
156        jcb.bindImplementation(GroupCommOperator.class, ScatterReceiver.class);
157      }
158    } else if (operatorSpec instanceof GatherOperatorSpec) {
159      final GatherOperatorSpec gatherOperatorSpec = (GatherOperatorSpec) operatorSpec;
160      if (taskId.equals(gatherOperatorSpec.getReceiverId())) {
161        jcb.bindImplementation(GroupCommOperator.class, GatherReceiver.class);
162      } else {
163        jcb.bindImplementation(GroupCommOperator.class, GatherSender.class);
164      }
165    }
166    final Configuration retConf = jcb.build();
167    LOG.exiting("TreeTopology", "getTaskConfig", getQualifiedName() + confSer.toString(retConf));
168    return retConf;
169  }
170
171  @Override
172  public int getNodeVersion(final String taskId) {
173    LOG.entering("TreeTopology", "getNodeVersion", new Object[]{getQualifiedName(), taskId});
174    final TaskNode node = nodes.get(taskId);
175    if (node == null) {
176      throw new RuntimeException(getQualifiedName() + taskId + " is not available on the nodes map");
177    }
178    final int version = node.getVersion();
179    LOG.exiting("TreeTopology", "getNodeVersion", getQualifiedName() + " " + taskId + " " + version);
180    return version;
181  }
182
183  @Override
184  public void removeTask(final String taskId) {
185    LOG.entering("TreeTopology", "removeTask", new Object[]{getQualifiedName(), taskId});
186    if (!nodes.containsKey(taskId)) {
187      LOG.fine("Trying to remove a non-existent node in the task graph");
188      LOG.exiting("TreeTopology", "removeTask", getQualifiedName());
189      return;
190    }
191    if (taskId.equals(rootId)) {
192      unsetRootNode(taskId);
193    } else {
194      removeChild(taskId);
195    }
196    LOG.exiting("TreeTopology", "removeTask", getQualifiedName() + taskId);
197  }
198
199  @Override
200  public void addTask(final String taskId) {
201    LOG.entering("TreeTopology", "addTask", new Object[]{getQualifiedName(), taskId});
202    if (nodes.containsKey(taskId)) {
203      LOG.fine("Got a request to add a task that is already in the graph. " +
204          "We need to block this request till the delete finishes. ***CAUTION***");
205    }
206
207    if (taskId.equals(rootId)) {
208      setRootNode(taskId);
209    } else {
210      addChild(taskId);
211    }
212    LOG.exiting("TreeTopology", "addTask", getQualifiedName() + taskId);
213  }
214
215  private void addChild(final String taskId) {
216    LOG.entering("TreeTopology", "addChild", new Object[]{getQualifiedName(), taskId});
217    LOG.finest(getQualifiedName() + "Adding leaf " + taskId);
218    final TaskNode node = new TaskNodeImpl(senderStage, groupName, operName, taskId, driverId, false);
219    if (logicalRoot != null) {
220      addTaskNode(node);
221      prev = node;
222    }
223    nodes.put(taskId, node);
224    LOG.exiting("TreeTopology", "addChild", getQualifiedName() + taskId);
225  }
226
227  private void addTaskNode(final TaskNode node) {
228    LOG.entering("TreeTopology", "addTaskNode", new Object[]{getQualifiedName(), node});
229    if (logicalRoot.getNumberOfChildren() >= this.fanOut) {
230      logicalRoot = logicalRoot.successor();
231    }
232    node.setParent(logicalRoot);
233    logicalRoot.addChild(node);
234    prev.setSibling(node);
235    LOG.exiting("TreeTopology", "addTaskNode", getQualifiedName() + node);
236  }
237
238  private void removeChild(final String taskId) {
239    LOG.entering("TreeTopology", "removeChild", new Object[]{getQualifiedName(), taskId});
240    if (root != null) {
241      root.removeChild(nodes.get(taskId));
242    }
243    nodes.remove(taskId);
244    LOG.exiting("TreeTopology", "removeChild", getQualifiedName() + taskId);
245  }
246
247  private void setRootNode(final String newRootId) {
248    LOG.entering("TreeTopology", "setRootNode", new Object[]{getQualifiedName(), newRootId});
249    this.root = new TaskNodeImpl(senderStage, groupName, operName, newRootId, driverId, true);
250    this.logicalRoot = this.root;
251    this.prev = this.root;
252
253    for (final Map.Entry<String, TaskNode> nodeEntry : nodes.entrySet()) {
254      final TaskNode leaf = nodeEntry.getValue();
255      addTaskNode(leaf);
256      this.prev = leaf;
257    }
258    nodes.put(newRootId, root);
259    LOG.exiting("TreeTopology", "setRootNode", getQualifiedName() + newRootId);
260  }
261
262  private void unsetRootNode(final String taskId) {
263    LOG.entering("TreeTopology", "unsetRootNode", new Object[]{getQualifiedName(), taskId});
264    nodes.remove(rootId);
265    root = null;
266
267    for (final Map.Entry<String, TaskNode> nodeEntry : nodes.entrySet()) {
268      final TaskNode leaf = nodeEntry.getValue();
269      leaf.setParent(null);
270    }
271    LOG.exiting("TreeTopology", "unsetRootNode", getQualifiedName() + taskId);
272  }
273
274  @Override
275  public void onFailedTask(final String taskId) {
276    LOG.entering("TreeTopology", "onFailedTask", new Object[]{getQualifiedName(), taskId});
277    final TaskNode taskNode = nodes.get(taskId);
278    if (taskNode == null) {
279      throw new RuntimeException(getQualifiedName() + taskId + " does not exist");
280    }
281    taskNode.onFailedTask();
282    LOG.exiting("TreeTopology", "onFailedTask", getQualifiedName() + taskId);
283  }
284
285  @Override
286  public void onRunningTask(final String taskId) {
287    LOG.entering("TreeTopology", "onRunningTask", new Object[]{getQualifiedName(), taskId});
288    final TaskNode taskNode = nodes.get(taskId);
289    if (taskNode == null) {
290      throw new RuntimeException(getQualifiedName() + taskId + " does not exist");
291    }
292    taskNode.onRunningTask();
293    LOG.exiting("TreeTopology", "onRunningTask", getQualifiedName() + taskId);
294  }
295
296  @Override
297  public void onReceiptOfMessage(final GroupCommunicationMessage msg) {
298    LOG.entering("TreeTopology", "onReceiptOfMessage", new Object[]{getQualifiedName(), msg});
299    switch (msg.getType()) {
300    case TopologyChanges:
301      onTopologyChanges(msg);
302      break;
303    case UpdateTopology:
304      onUpdateTopology(msg);
305      break;
306
307    default:
308      nodes.get(msg.getSrcid()).onReceiptOfAcknowledgement(msg);
309      break;
310    }
311    LOG.exiting("TreeTopology", "onReceiptOfMessage", getQualifiedName() + msg);
312  }
313
314  private void onUpdateTopology(final GroupCommunicationMessage msg) {
315    LOG.entering("TreeTopology", "onUpdateTopology", new Object[]{getQualifiedName(), msg});
316    LOG.fine(getQualifiedName() + "Update affected parts of Topology");
317    final String dstId = msg.getSrcid();
318    final int version = getNodeVersion(dstId);
319
320    LOG.finest(getQualifiedName() + "Creating NodeTopologyUpdateWaitStage to wait on nodes to be updated");
321    final EventHandler<List<TaskNode>> topoUpdateWaitHandler = new TopologyUpdateWaitHandler(senderStage, groupName,
322        operName, driverId, 0,
323        dstId, version,
324        getQualifiedName(), TopologySerializer.encode(root));
325    final EStage<List<TaskNode>> nodeTopologyUpdateWaitStage = new SingleThreadStage<>("NodeTopologyUpdateWaitStage",
326        topoUpdateWaitHandler,
327        nodes.size());
328
329    final List<TaskNode> toBeUpdatedNodes = new ArrayList<>(nodes.size());
330    LOG.finest(getQualifiedName() + "Checking which nodes need to be updated");
331    for (final TaskNode node : nodes.values()) {
332      if (node.isRunning() && node.hasChanges() && node.resetTopologySetupSent()) {
333        toBeUpdatedNodes.add(node);
334      }
335    }
336    for (final TaskNode node : toBeUpdatedNodes) {
337      node.updatingTopology();
338      LOG.fine(getQualifiedName() + "Asking " + node + " to UpdateTopology");
339      senderStage.onNext(Utils.bldVersionedGCM(groupName, operName,
340          ReefNetworkGroupCommProtos.GroupCommMessage.Type.UpdateTopology, driverId, 0, node.getTaskId(),
341          node.getVersion(), Utils.EMPTY_BYTE_ARR));
342    }
343    nodeTopologyUpdateWaitStage.onNext(toBeUpdatedNodes);
344    LOG.exiting("TreeTopology", "onUpdateTopology", getQualifiedName() + msg);
345  }
346
347  private void onTopologyChanges(final GroupCommunicationMessage msg) {
348    LOG.entering("TreeTopology", "onTopologyChanges", new Object[]{getQualifiedName(), msg});
349    LOG.fine(getQualifiedName() + "Check TopologyChanges");
350    final String dstId = msg.getSrcid();
351    boolean hasTopologyChanged = false;
352    LOG.finest(getQualifiedName() + "Checking which nodes need to be updated");
353    for (final TaskNode node : nodes.values()) {
354      if (!node.isRunning() || node.hasChanges()) {
355        hasTopologyChanged = true;
356        break;
357      }
358    }
359    final GroupChanges changes = new GroupChangesImpl(hasTopologyChanged);
360    final Codec<GroupChanges> changesCodec = new GroupChangesCodec();
361    LOG.fine(getQualifiedName() + "TopologyChanges: " + changes);
362    senderStage.onNext(Utils.bldVersionedGCM(groupName, operName,
363        ReefNetworkGroupCommProtos.GroupCommMessage.Type.TopologyChanges, driverId, 0, dstId, getNodeVersion(dstId),
364        changesCodec.encode(changes)));
365    LOG.exiting("TreeTopology", "onTopologyChanges", getQualifiedName() + msg);
366  }
367
368  private String getQualifiedName() {
369    return Utils.simpleName(groupName) + ":" + Utils.simpleName(operName) + " - ";
370  }
371}