This project has retired. For details please refer to its Attic page.
Source code
001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package org.apache.reef.io.network.group.impl.utils;
020
021import org.apache.reef.io.network.group.api.task.CommunicationGroupServiceClient;
022import org.apache.reef.io.network.group.impl.driver.TopologySimpleNode;
023import org.apache.reef.io.serialization.Codec;
024import org.apache.reef.wake.Identifier;
025
026import javax.inject.Inject;
027import java.io.ByteArrayOutputStream;
028import java.io.DataOutputStream;
029import java.io.IOException;
030import java.util.HashMap;
031import java.util.List;
032import java.util.Map;
033
034/**
035 * Encode messages for a scatter operation, which can be decoded by {@code ScatterDecoder}.
036 */
037public final class ScatterEncoder {
038
039  private final CommunicationGroupServiceClient commGroupClient;
040
041  @Inject
042  ScatterEncoder(final CommunicationGroupServiceClient commGroupClient) {
043    this.commGroupClient = commGroupClient;
044  }
045
046  public <T> Map<String, byte[]> encode(final List<T> elements,
047                                        final List<Integer> counts,
048                                        final List<? extends Identifier> taskOrder,
049                                        final Codec<T> dataCodec) {
050
051    // first assign data to all tasks
052    final Map<String, byte[]> taskIdToBytes = encodeAndDistributeElements(elements, counts, taskOrder, dataCodec);
053    // then organize the data so that a node keeps its own data as well as its descendants' data
054    final Map<String, byte[]> childIdToBytes = new HashMap<>();
055
056    for (final TopologySimpleNode node : commGroupClient.getTopologySimpleNodeRoot().getChildren()) {
057      childIdToBytes.put(node.getTaskId(), encodeScatterMsgForNode(node, taskIdToBytes));
058    }
059    return childIdToBytes;
060  }
061
062  /**
063   * Compute a single byte array message for a node and its children.
064   * Using {@code taskIdToBytes}, we pack all messages for a
065   * {@code TopologySimpleNode} and its children into a single byte array.
066   *
067   * @param node the target TopologySimpleNode to generate a message for
068   * @param taskIdToBytes map containing byte array of encoded data for individual Tasks
069   * @return single byte array message
070   */
071  private byte[] encodeScatterMsgForNode(final TopologySimpleNode node,
072                                         final Map<String, byte[]> taskIdToBytes) {
073
074    try (final ByteArrayOutputStream bstream = new ByteArrayOutputStream();
075         final DataOutputStream dstream = new DataOutputStream(bstream)) {
076
077      // first write the node's encoded data
078      final String taskId = node.getTaskId();
079      if (taskIdToBytes.containsKey(taskId)) {
080        dstream.write(taskIdToBytes.get(node.getTaskId()));
081
082      } else {
083        // in case mapOfTaskToBytes does not contain this node's id, write an empty
084        // message (zero elements)
085        dstream.writeInt(0);
086      }
087
088      // and then write its children's identifiers and their encoded data
089      for (final TopologySimpleNode child : node.getChildren()) {
090        dstream.writeUTF(child.getTaskId());
091        final byte[] childData = encodeScatterMsgForNode(child, taskIdToBytes);
092        dstream.writeInt(childData.length);
093        dstream.write(childData);
094      }
095
096      return bstream.toByteArray();
097
098    } catch (final IOException e) {
099      throw new RuntimeException("IOException", e);
100    }
101  }
102
103  /**
104   * Encode elements into byte arrays, and distribute them across Tasks indicated by Identifiers.
105   * Note that elements are distributed in the exact order specified in
106   * {@code elements} and not in a round-robin fashion.
107   * For example, (1, 2, 3, 4) uniformly distributed to (task1, task2, task3) would be
108   * {task1: (1, 2), task2: (3), task3: (4)}.
109   *
110   * @param elements list of data elements to encode
111   * @param counts list of numbers specifying how many elements each Task should receive
112   * @param taskOrder list of Identifiers indicating Task Ids
113   * @param codec class for encoding data
114   * @param <T> type of data
115   * @return byte representation of a map of identifiers to encoded data
116   */
117  private <T> Map<String, byte[]> encodeAndDistributeElements(final List<T> elements,
118                                                              final List<Integer> counts,
119                                                              final List<? extends Identifier> taskOrder,
120                                                              final Codec<T> codec) {
121    final Map<String, byte[]> taskIdToBytes = new HashMap<>();
122
123    int elementsIndex = 0;
124    for (int taskOrderIndex = 0; taskOrderIndex < taskOrder.size(); taskOrderIndex++) {
125      final int elementCount = counts.get(taskOrderIndex);
126
127      try (final ByteArrayOutputStream bstream = new ByteArrayOutputStream();
128           final DataOutputStream dstream = new DataOutputStream(bstream)) {
129
130        dstream.writeInt(elementCount);
131        for (final T element : elements.subList(elementsIndex, elementsIndex + elementCount)) {
132          final byte[] encodedElement = codec.encode(element);
133          dstream.writeInt(encodedElement.length);
134          dstream.write(encodedElement);
135        }
136        taskIdToBytes.put(taskOrder.get(taskOrderIndex).toString(), bstream.toByteArray());
137
138      } catch (final IOException e) {
139        throw new RuntimeException("IOException",  e);
140      }
141
142      elementsIndex += elementCount;
143    }
144
145    return taskIdToBytes;
146  }
147}