This project has retired. For details please refer to its Attic page.
Source code
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
019package org.apache.reef.io.network.group.impl.driver;
020
021import org.apache.reef.driver.parameters.DriverIdentifier;
022import org.apache.reef.io.network.group.api.operators.GroupCommOperator;
023import org.apache.reef.io.network.group.api.GroupChanges;
024import org.apache.reef.io.network.group.api.config.OperatorSpec;
025import org.apache.reef.io.network.group.api.driver.TaskNode;
026import org.apache.reef.io.network.group.api.driver.Topology;
027import org.apache.reef.io.network.group.impl.GroupChangesCodec;
028import org.apache.reef.io.network.group.impl.GroupChangesImpl;
029import org.apache.reef.io.network.group.impl.GroupCommunicationMessage;
030import org.apache.reef.io.network.group.impl.config.BroadcastOperatorSpec;
031import org.apache.reef.io.network.group.impl.config.GatherOperatorSpec;
032import org.apache.reef.io.network.group.impl.config.ReduceOperatorSpec;
033import org.apache.reef.io.network.group.impl.config.ScatterOperatorSpec;
034import org.apache.reef.io.network.group.impl.config.parameters.*;
035import org.apache.reef.io.network.group.impl.operators.*;
036import org.apache.reef.io.network.group.impl.utils.Utils;
037import org.apache.reef.io.network.proto.ReefNetworkGroupCommProtos;
038import org.apache.reef.io.serialization.Codec;
039import org.apache.reef.tang.Configuration;
040import org.apache.reef.tang.JavaConfigurationBuilder;
041import org.apache.reef.tang.Tang;
042import org.apache.reef.tang.annotations.Name;
043import org.apache.reef.tang.annotations.Parameter;
044import org.apache.reef.wake.EStage;
045import org.apache.reef.wake.EventHandler;
046import org.apache.reef.wake.impl.SingleThreadStage;
047
048import javax.inject.Inject;
049import java.util.ArrayList;
050import java.util.List;
051import java.util.Map;
052import java.util.concurrent.ConcurrentMap;
053import java.util.concurrent.ConcurrentSkipListMap;
054import java.util.logging.Logger;
055
056/**
057 * Implements a one level Tree Topology.
058 */
059public final class FlatTopology implements Topology {
060
061  private static final Logger LOG = Logger.getLogger(FlatTopology.class.getName());
062
063  private final EStage<GroupCommunicationMessage> senderStage;
064  private final Class<? extends Name<String>> groupName;
065  private final Class<? extends Name<String>> operName;
066  private final String driverId;
067  private String rootId;
068  private OperatorSpec operatorSpec;
069
070  private TaskNode root;
071  private final ConcurrentMap<String, TaskNode> nodes = new ConcurrentSkipListMap<>();
072
073  @Inject
074  private FlatTopology(@Parameter(GroupCommSenderStage.class) final EStage<GroupCommunicationMessage> senderStage,
075                       @Parameter(CommGroupNameClass.class) final Class<? extends Name<String>> groupName,
076                       @Parameter(OperatorNameClass.class) final Class<? extends Name<String>> operatorName,
077                       @Parameter(DriverIdentifier.class) final String driverId) {
078    this.senderStage = senderStage;
079    this.groupName = groupName;
080    this.operName = operatorName;
081    this.driverId = driverId;
082  }
083
084  @Override
085  @SuppressWarnings("checkstyle:hiddenfield")
086  public void setRootTask(final String rootId) {
087    this.rootId = rootId;
088  }
089
090  /**
091   * @return the rootId
092   */
093  @Override
094  public String getRootId() {
095    return rootId;
096  }
097
098  @Override
099  public boolean isRootPresent() {
100    return root != null;
101  }
102
103  @Override
104  public void setOperatorSpecification(final OperatorSpec spec) {
105    this.operatorSpec = spec;
106  }
107
108  @Override
109  public Configuration getTaskConfiguration(final String taskId) {
110    LOG.finest(getQualifiedName() + "Getting config for task " + taskId);
111    final TaskNode taskNode = nodes.get(taskId);
112    if (taskNode == null) {
113      throw new RuntimeException(getQualifiedName() + taskId + " does not exist");
114    }
115
116    final int version;
117    version = getNodeVersion(taskId);
118    final JavaConfigurationBuilder jcb = Tang.Factory.getTang().newConfigurationBuilder();
119    jcb.bindNamedParameter(DataCodec.class, operatorSpec.getDataCodecClass());
120    jcb.bindNamedParameter(TaskVersion.class, Integer.toString(version));
121    if (operatorSpec instanceof BroadcastOperatorSpec) {
122      final BroadcastOperatorSpec broadcastOperatorSpec = (BroadcastOperatorSpec) operatorSpec;
123      if (taskId.equals(broadcastOperatorSpec.getSenderId())) {
124        jcb.bindImplementation(GroupCommOperator.class, BroadcastSender.class);
125      } else {
126        jcb.bindImplementation(GroupCommOperator.class, BroadcastReceiver.class);
127      }
128    } else if (operatorSpec instanceof ReduceOperatorSpec) {
129      final ReduceOperatorSpec reduceOperatorSpec = (ReduceOperatorSpec) operatorSpec;
130      jcb.bindNamedParameter(ReduceFunctionParam.class, reduceOperatorSpec.getRedFuncClass());
131      if (taskId.equals(reduceOperatorSpec.getReceiverId())) {
132        jcb.bindImplementation(GroupCommOperator.class, ReduceReceiver.class);
133      } else {
134        jcb.bindImplementation(GroupCommOperator.class, ReduceSender.class);
135      }
136    } else if (operatorSpec instanceof ScatterOperatorSpec) {
137      final ScatterOperatorSpec scatterOperatorSpec = (ScatterOperatorSpec) operatorSpec;
138      if (taskId.equals(scatterOperatorSpec.getSenderId())) {
139        jcb.bindImplementation(GroupCommOperator.class, ScatterSender.class);
140      } else {
141        jcb.bindImplementation(GroupCommOperator.class, ScatterReceiver.class);
142      }
143    } else if (operatorSpec instanceof GatherOperatorSpec) {
144      final GatherOperatorSpec gatherOperatorSpec = (GatherOperatorSpec) operatorSpec;
145      if (taskId.equals(gatherOperatorSpec.getReceiverId())) {
146        jcb.bindImplementation(GroupCommOperator.class, GatherReceiver.class);
147      } else {
148        jcb.bindImplementation(GroupCommOperator.class, GatherSender.class);
149      }
150    }
151    return jcb.build();
152  }
153
154  @Override
155  public int getNodeVersion(final String taskId) {
156    final TaskNode node = nodes.get(taskId);
157    if (node == null) {
158      throw new RuntimeException(getQualifiedName() + taskId + " is not available on the nodes map");
159    }
160    final int version = node.getVersion();
161    return version;
162  }
163
164  @Override
165  public void removeTask(final String taskId) {
166    if (!nodes.containsKey(taskId)) {
167      LOG.warning("Trying to remove a non-existent node in the task graph");
168      return;
169    }
170    if (taskId.equals(rootId)) {
171      unsetRootNode(taskId);
172    } else {
173      removeChild(taskId);
174    }
175  }
176
177  @Override
178  public void addTask(final String taskId) {
179    if (nodes.containsKey(taskId)) {
180      LOG.warning("Got a request to add a task that is already in the graph");
181      LOG.warning("We need to block this request till the delete finishes");
182    }
183    if (taskId.equals(rootId)) {
184      setRootNode(taskId);
185    } else {
186      addChild(taskId);
187    }
188  }
189
190  /**
191   * @param taskId
192   */
193  private void addChild(final String taskId) {
194    LOG.finest(getQualifiedName() + "Adding leaf " + taskId);
195    final TaskNode node = new TaskNodeImpl(senderStage, groupName, operName, taskId, driverId, false);
196    final TaskNode leaf = node;
197    if (root != null) {
198      leaf.setParent(root);
199      root.addChild(leaf);
200    }
201    nodes.put(taskId, leaf);
202  }
203
204  /**
205   * @param taskId
206   */
207  private void removeChild(final String taskId) {
208    LOG.finest(getQualifiedName() + "Removing leaf " + taskId);
209    if (root != null) {
210      root.removeChild(nodes.get(taskId));
211    }
212    nodes.remove(taskId);
213  }
214
215  private void setRootNode(final String newRootId) {
216    LOG.finest(getQualifiedName() + "Setting " + newRootId + " as root");
217    final TaskNode node = new TaskNodeImpl(senderStage, groupName, operName, newRootId, driverId, true);
218    this.root = node;
219
220    for (final Map.Entry<String, TaskNode> nodeEntry : nodes.entrySet()) {
221      final TaskNode leaf = nodeEntry.getValue();
222      root.addChild(leaf);
223      leaf.setParent(root);
224    }
225    nodes.put(newRootId, root);
226  }
227
228  /**
229   * @param taskId
230   */
231  private void unsetRootNode(final String taskId) {
232    LOG.finest(getQualifiedName() + "Unsetting " + rootId + " as root");
233    nodes.remove(rootId);
234    root = null;
235
236    for (final Map.Entry<String, TaskNode> nodeEntry : nodes.entrySet()) {
237      final TaskNode leaf = nodeEntry.getValue();
238      leaf.setParent(null);
239    }
240  }
241
242  @Override
243  public void onFailedTask(final String id) {
244    LOG.finest(getQualifiedName() + "Task-" + id + " failed");
245    final TaskNode taskNode = nodes.get(id);
246    if (taskNode == null) {
247      throw new RuntimeException(getQualifiedName() + id + " does not exist");
248    }
249
250    taskNode.onFailedTask();
251  }
252
253  @Override
254  public void onRunningTask(final String id) {
255    LOG.finest(getQualifiedName() + "Task-" + id + " is running");
256    final TaskNode taskNode = nodes.get(id);
257    if (taskNode == null) {
258      throw new RuntimeException(getQualifiedName() + id + " does not exist");
259    }
260
261    taskNode.onRunningTask();
262  }
263
264  @Override
265  public void onReceiptOfMessage(final GroupCommunicationMessage msg) {
266    LOG.finest(getQualifiedName() + "processing " + msg.getType() + " from " + msg.getSrcid());
267    if (msg.getType().equals(ReefNetworkGroupCommProtos.GroupCommMessage.Type.TopologyChanges)) {
268      processTopologyChanges(msg);
269      return;
270    }
271    if (msg.getType().equals(ReefNetworkGroupCommProtos.GroupCommMessage.Type.UpdateTopology)) {
272      processUpdateTopology(msg);
273      return;
274    }
275    final String id = msg.getSrcid();
276    nodes.get(id).onReceiptOfAcknowledgement(msg);
277  }
278
279  private void processUpdateTopology(final GroupCommunicationMessage msg) {
280    final String dstId = msg.getSrcid();
281    final int version = getNodeVersion(dstId);
282
283    LOG.finest(getQualifiedName() + "Creating NodeTopologyUpdateWaitStage to wait on nodes to be updated");
284    final EventHandler<List<TaskNode>> topoUpdateWaitHandler = new TopologyUpdateWaitHandler(senderStage, groupName,
285        operName, driverId, 0,
286        dstId, version,
287        getQualifiedName(), TopologySerializer.encode(root));
288    final EStage<List<TaskNode>> nodeTopologyUpdateWaitStage = new SingleThreadStage<>("NodeTopologyUpdateWaitStage",
289        topoUpdateWaitHandler,
290        nodes.size());
291
292    final List<TaskNode> toBeUpdatedNodes = new ArrayList<>(nodes.size());
293    LOG.finest(getQualifiedName() + "Checking which nodes need to be updated");
294    for (final TaskNode node : nodes.values()) {
295      if (node.isRunning() && node.hasChanges() && node.resetTopologySetupSent()) {
296        toBeUpdatedNodes.add(node);
297      }
298    }
299    for (final TaskNode node : toBeUpdatedNodes) {
300      node.updatingTopology();
301      senderStage.onNext(Utils.bldVersionedGCM(groupName, operName,
302          ReefNetworkGroupCommProtos.GroupCommMessage.Type.UpdateTopology, driverId, 0, node.getTaskId(),
303          node.getVersion(), Utils.EMPTY_BYTE_ARR));
304    }
305    nodeTopologyUpdateWaitStage.onNext(toBeUpdatedNodes);
306  }
307
308  private void processTopologyChanges(final GroupCommunicationMessage msg) {
309    final String dstId = msg.getSrcid();
310    boolean hasTopologyChanged = false;
311    LOG.finest(getQualifiedName() + "Checking which nodes need to be updated");
312    for (final TaskNode node : nodes.values()) {
313      if (!node.isRunning() || node.hasChanges()) {
314        hasTopologyChanged = true;
315        break;
316      }
317    }
318    final GroupChanges changes = new GroupChangesImpl(hasTopologyChanged);
319    final Codec<GroupChanges> changesCodec = new GroupChangesCodec();
320    senderStage.onNext(Utils.bldVersionedGCM(groupName, operName,
321        ReefNetworkGroupCommProtos.GroupCommMessage.Type.TopologyChanges, driverId, 0, dstId, getNodeVersion(dstId),
322        changesCodec.encode(changes)));
323  }
324
325  private String getQualifiedName() {
326    return Utils.simpleName(groupName) + ":" + Utils.simpleName(operName) + " - ";
327  }
328}