This project has retired. For details please refer to its Attic page.
Source code
001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package org.apache.reef.io.network.group.impl.driver;
020
021import org.apache.reef.driver.parameters.DriverIdentifier;
022import org.apache.reef.io.network.group.api.operators.GroupCommOperator;
023import org.apache.reef.io.network.group.api.GroupChanges;
024import org.apache.reef.io.network.group.api.config.OperatorSpec;
025import org.apache.reef.io.network.group.api.driver.TaskNode;
026import org.apache.reef.io.network.group.api.driver.Topology;
027import org.apache.reef.io.network.group.impl.GroupChangesCodec;
028import org.apache.reef.io.network.group.impl.GroupChangesImpl;
029import org.apache.reef.io.network.group.impl.GroupCommunicationMessage;
030import org.apache.reef.io.network.group.impl.config.BroadcastOperatorSpec;
031import org.apache.reef.io.network.group.impl.config.GatherOperatorSpec;
032import org.apache.reef.io.network.group.impl.config.ReduceOperatorSpec;
033import org.apache.reef.io.network.group.impl.config.ScatterOperatorSpec;
034import org.apache.reef.io.network.group.impl.config.parameters.*;
035import org.apache.reef.io.network.group.impl.operators.*;
036import org.apache.reef.io.network.group.impl.utils.Utils;
037import org.apache.reef.io.network.proto.ReefNetworkGroupCommProtos;
038import org.apache.reef.io.serialization.Codec;
039import org.apache.reef.tang.Configuration;
040import org.apache.reef.tang.JavaConfigurationBuilder;
041import org.apache.reef.tang.Tang;
042import org.apache.reef.tang.annotations.Name;
043import org.apache.reef.tang.annotations.Parameter;
044import org.apache.reef.wake.EStage;
045import org.apache.reef.wake.EventHandler;
046import org.apache.reef.wake.impl.SingleThreadStage;
047
048import javax.inject.Inject;
049import java.util.ArrayList;
050import java.util.List;
051import java.util.Map;
052import java.util.concurrent.ConcurrentMap;
053import java.util.concurrent.ConcurrentSkipListMap;
054import java.util.logging.Logger;
055
056/**
057 * Implements a one level Tree Topology.
058 */
059public class FlatTopology implements Topology {
060
061  private static final Logger LOG = Logger.getLogger(FlatTopology.class.getName());
062
063  private final EStage<GroupCommunicationMessage> senderStage;
064  private final Class<? extends Name<String>> groupName;
065  private final Class<? extends Name<String>> operName;
066  private final String driverId;
067  private String rootId;
068  private OperatorSpec operatorSpec;
069
070  private TaskNode root;
071  private final ConcurrentMap<String, TaskNode> nodes = new ConcurrentSkipListMap<>();
072
073  /**
074   * @deprecated in 0.14. Use Tang to obtain an instance of this instead.
075   */
076  @Deprecated
077  public FlatTopology(final EStage<GroupCommunicationMessage> senderStage,
078                      final Class<? extends Name<String>> groupName,
079                      final Class<? extends Name<String>> operatorName,
080                      final String driverId, final int numberOfTasks) {
081    this.senderStage = senderStage;
082    this.groupName = groupName;
083    this.operName = operatorName;
084    this.driverId = driverId;
085  }
086
087  @Inject
088  private FlatTopology(@Parameter(GroupCommSenderStage.class) final EStage<GroupCommunicationMessage> senderStage,
089                       @Parameter(CommGroupNameClass.class) final Class<? extends Name<String>> groupName,
090                       @Parameter(OperatorNameClass.class) final Class<? extends Name<String>> operatorName,
091                       @Parameter(DriverIdentifier.class) final String driverId) {
092    this.senderStage = senderStage;
093    this.groupName = groupName;
094    this.operName = operatorName;
095    this.driverId = driverId;
096  }
097
098  @Override
099  @SuppressWarnings("checkstyle:hiddenfield")
100  public void setRootTask(final String rootId) {
101    this.rootId = rootId;
102  }
103
104  /**
105   * @return the rootId
106   */
107  @Override
108  public String getRootId() {
109    return rootId;
110  }
111
112  @Override
113  public boolean isRootPresent() {
114    return root != null;
115  }
116
117  @Override
118  public void setOperatorSpecification(final OperatorSpec spec) {
119    this.operatorSpec = spec;
120  }
121
122  @Override
123  public Configuration getTaskConfiguration(final String taskId) {
124    LOG.finest(getQualifiedName() + "Getting config for task " + taskId);
125    final TaskNode taskNode = nodes.get(taskId);
126    if (taskNode == null) {
127      throw new RuntimeException(getQualifiedName() + taskId + " does not exist");
128    }
129
130    final int version;
131    version = getNodeVersion(taskId);
132    final JavaConfigurationBuilder jcb = Tang.Factory.getTang().newConfigurationBuilder();
133    jcb.bindNamedParameter(DataCodec.class, operatorSpec.getDataCodecClass());
134    jcb.bindNamedParameter(TaskVersion.class, Integer.toString(version));
135    if (operatorSpec instanceof BroadcastOperatorSpec) {
136      final BroadcastOperatorSpec broadcastOperatorSpec = (BroadcastOperatorSpec) operatorSpec;
137      if (taskId.equals(broadcastOperatorSpec.getSenderId())) {
138        jcb.bindImplementation(GroupCommOperator.class, BroadcastSender.class);
139      } else {
140        jcb.bindImplementation(GroupCommOperator.class, BroadcastReceiver.class);
141      }
142    } else if (operatorSpec instanceof ReduceOperatorSpec) {
143      final ReduceOperatorSpec reduceOperatorSpec = (ReduceOperatorSpec) operatorSpec;
144      jcb.bindNamedParameter(ReduceFunctionParam.class, reduceOperatorSpec.getRedFuncClass());
145      if (taskId.equals(reduceOperatorSpec.getReceiverId())) {
146        jcb.bindImplementation(GroupCommOperator.class, ReduceReceiver.class);
147      } else {
148        jcb.bindImplementation(GroupCommOperator.class, ReduceSender.class);
149      }
150    } else if (operatorSpec instanceof ScatterOperatorSpec) {
151      final ScatterOperatorSpec scatterOperatorSpec = (ScatterOperatorSpec) operatorSpec;
152      if (taskId.equals(scatterOperatorSpec.getSenderId())) {
153        jcb.bindImplementation(GroupCommOperator.class, ScatterSender.class);
154      } else {
155        jcb.bindImplementation(GroupCommOperator.class, ScatterReceiver.class);
156      }
157    } else if (operatorSpec instanceof GatherOperatorSpec) {
158      final GatherOperatorSpec gatherOperatorSpec = (GatherOperatorSpec) operatorSpec;
159      if (taskId.equals(gatherOperatorSpec.getReceiverId())) {
160        jcb.bindImplementation(GroupCommOperator.class, GatherReceiver.class);
161      } else {
162        jcb.bindImplementation(GroupCommOperator.class, GatherSender.class);
163      }
164    }
165    return jcb.build();
166  }
167
168  @Override
169  public int getNodeVersion(final String taskId) {
170    final TaskNode node = nodes.get(taskId);
171    if (node == null) {
172      throw new RuntimeException(getQualifiedName() + taskId + " is not available on the nodes map");
173    }
174    final int version = node.getVersion();
175    return version;
176  }
177
178  @Override
179  public void removeTask(final String taskId) {
180    if (!nodes.containsKey(taskId)) {
181      LOG.warning("Trying to remove a non-existent node in the task graph");
182      return;
183    }
184    if (taskId.equals(rootId)) {
185      unsetRootNode(taskId);
186    } else {
187      removeChild(taskId);
188    }
189  }
190
191  @Override
192  public void addTask(final String taskId) {
193    if (nodes.containsKey(taskId)) {
194      LOG.warning("Got a request to add a task that is already in the graph");
195      LOG.warning("We need to block this request till the delete finishes");
196    }
197    if (taskId.equals(rootId)) {
198      setRootNode(taskId);
199    } else {
200      addChild(taskId);
201    }
202  }
203
204  /**
205   * @param taskId
206   */
207  private void addChild(final String taskId) {
208    LOG.finest(getQualifiedName() + "Adding leaf " + taskId);
209    final TaskNode node = new TaskNodeImpl(senderStage, groupName, operName, taskId, driverId, false);
210    final TaskNode leaf = node;
211    if (root != null) {
212      leaf.setParent(root);
213      root.addChild(leaf);
214    }
215    nodes.put(taskId, leaf);
216  }
217
218  /**
219   * @param taskId
220   */
221  private void removeChild(final String taskId) {
222    LOG.finest(getQualifiedName() + "Removing leaf " + taskId);
223    if (root != null) {
224      root.removeChild(nodes.get(taskId));
225    }
226    nodes.remove(taskId);
227  }
228
229  private void setRootNode(final String newRootId) {
230    LOG.finest(getQualifiedName() + "Setting " + newRootId + " as root");
231    final TaskNode node = new TaskNodeImpl(senderStage, groupName, operName, newRootId, driverId, true);
232    this.root = node;
233
234    for (final Map.Entry<String, TaskNode> nodeEntry : nodes.entrySet()) {
235      final TaskNode leaf = nodeEntry.getValue();
236      root.addChild(leaf);
237      leaf.setParent(root);
238    }
239    nodes.put(newRootId, root);
240  }
241
242  /**
243   * @param taskId
244   */
245  private void unsetRootNode(final String taskId) {
246    LOG.finest(getQualifiedName() + "Unsetting " + rootId + " as root");
247    nodes.remove(rootId);
248    root = null;
249
250    for (final Map.Entry<String, TaskNode> nodeEntry : nodes.entrySet()) {
251      final TaskNode leaf = nodeEntry.getValue();
252      leaf.setParent(null);
253    }
254  }
255
256  @Override
257  public void onFailedTask(final String id) {
258    LOG.finest(getQualifiedName() + "Task-" + id + " failed");
259    final TaskNode taskNode = nodes.get(id);
260    if (taskNode == null) {
261      throw new RuntimeException(getQualifiedName() + id + " does not exist");
262    }
263
264    taskNode.onFailedTask();
265  }
266
267  @Override
268  public void onRunningTask(final String id) {
269    LOG.finest(getQualifiedName() + "Task-" + id + " is running");
270    final TaskNode taskNode = nodes.get(id);
271    if (taskNode == null) {
272      throw new RuntimeException(getQualifiedName() + id + " does not exist");
273    }
274
275    taskNode.onRunningTask();
276  }
277
278  @Override
279  public void onReceiptOfMessage(final GroupCommunicationMessage msg) {
280    LOG.finest(getQualifiedName() + "processing " + msg.getType() + " from " + msg.getSrcid());
281    if (msg.getType().equals(ReefNetworkGroupCommProtos.GroupCommMessage.Type.TopologyChanges)) {
282      processTopologyChanges(msg);
283      return;
284    }
285    if (msg.getType().equals(ReefNetworkGroupCommProtos.GroupCommMessage.Type.UpdateTopology)) {
286      processUpdateTopology(msg);
287      return;
288    }
289    final String id = msg.getSrcid();
290    nodes.get(id).onReceiptOfAcknowledgement(msg);
291  }
292
293  private void processUpdateTopology(final GroupCommunicationMessage msg) {
294    final String dstId = msg.getSrcid();
295    final int version = getNodeVersion(dstId);
296
297    LOG.finest(getQualifiedName() + "Creating NodeTopologyUpdateWaitStage to wait on nodes to be updated");
298    final EventHandler<List<TaskNode>> topoUpdateWaitHandler = new TopologyUpdateWaitHandler(senderStage, groupName,
299        operName, driverId, 0,
300        dstId, version,
301        getQualifiedName(), TopologySerializer.encode(root));
302    final EStage<List<TaskNode>> nodeTopologyUpdateWaitStage = new SingleThreadStage<>("NodeTopologyUpdateWaitStage",
303        topoUpdateWaitHandler,
304        nodes.size());
305
306    final List<TaskNode> toBeUpdatedNodes = new ArrayList<>(nodes.size());
307    LOG.finest(getQualifiedName() + "Checking which nodes need to be updated");
308    for (final TaskNode node : nodes.values()) {
309      if (node.isRunning() && node.hasChanges() && node.resetTopologySetupSent()) {
310        toBeUpdatedNodes.add(node);
311      }
312    }
313    for (final TaskNode node : toBeUpdatedNodes) {
314      node.updatingTopology();
315      senderStage.onNext(Utils.bldVersionedGCM(groupName, operName,
316          ReefNetworkGroupCommProtos.GroupCommMessage.Type.UpdateTopology, driverId, 0, node.getTaskId(),
317          node.getVersion(), Utils.EMPTY_BYTE_ARR));
318    }
319    nodeTopologyUpdateWaitStage.onNext(toBeUpdatedNodes);
320  }
321
322  private void processTopologyChanges(final GroupCommunicationMessage msg) {
323    final String dstId = msg.getSrcid();
324    boolean hasTopologyChanged = false;
325    LOG.finest(getQualifiedName() + "Checking which nodes need to be updated");
326    for (final TaskNode node : nodes.values()) {
327      if (!node.isRunning() || node.hasChanges()) {
328        hasTopologyChanged = true;
329        break;
330      }
331    }
332    final GroupChanges changes = new GroupChangesImpl(hasTopologyChanged);
333    final Codec<GroupChanges> changesCodec = new GroupChangesCodec();
334    senderStage.onNext(Utils.bldVersionedGCM(groupName, operName,
335        ReefNetworkGroupCommProtos.GroupCommMessage.Type.TopologyChanges, driverId, 0, dstId, getNodeVersion(dstId),
336        changesCodec.encode(changes)));
337  }
338
339  private String getQualifiedName() {
340    return Utils.simpleName(groupName) + ":" + Utils.simpleName(operName) + " - ";
341  }
342}