001/* 002 * Licensed to the Apache Software Foundation (ASF) under one 003 * or more contributor license agreements. See the NOTICE file 004 * distributed with this work for additional information 005 * regarding copyright ownership. The ASF licenses this file 006 * to you under the Apache License, Version 2.0 (the 007 * "License"); you may not use this file except in compliance 008 * with the License. You may obtain a copy of the License at 009 * 010 * http://www.apache.org/licenses/LICENSE-2.0 011 * 012 * Unless required by applicable law or agreed to in writing, 013 * software distributed under the License is distributed on an 014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 015 * KIND, either express or implied. See the License for the 016 * specific language governing permissions and limitations 017 * under the License. 018 */ 019package org.apache.reef.io.network.group.impl.utils; 020 021import org.apache.reef.io.network.group.api.task.CommunicationGroupServiceClient; 022import org.apache.reef.io.network.group.impl.driver.TopologySimpleNode; 023import org.apache.reef.io.serialization.Codec; 024import org.apache.reef.wake.Identifier; 025 026import javax.inject.Inject; 027import java.io.ByteArrayOutputStream; 028import java.io.DataOutputStream; 029import java.io.IOException; 030import java.util.HashMap; 031import java.util.List; 032import java.util.Map; 033 034/** 035 * Encode messages for a scatter operation, which can be decoded by {@code ScatterDecoder}. 036 */ 037public final class ScatterEncoder { 038 039 private final CommunicationGroupServiceClient commGroupClient; 040 041 @Inject 042 ScatterEncoder(final CommunicationGroupServiceClient commGroupClient) { 043 this.commGroupClient = commGroupClient; 044 } 045 046 public <T> Map<String, byte[]> encode(final List<T> elements, 047 final List<Integer> counts, 048 final List<? extends Identifier> taskOrder, 049 final Codec<T> dataCodec) { 050 051 // first assign data to all tasks 052 final Map<String, byte[]> taskIdToBytes = encodeAndDistributeElements(elements, counts, taskOrder, dataCodec); 053 // then organize the data so that a node keeps its own data as well as its descendants' data 054 final Map<String, byte[]> childIdToBytes = new HashMap<>(); 055 056 for (final TopologySimpleNode node : commGroupClient.getTopologySimpleNodeRoot().getChildren()) { 057 childIdToBytes.put(node.getTaskId(), encodeScatterMsgForNode(node, taskIdToBytes)); 058 } 059 return childIdToBytes; 060 } 061 062 /** 063 * Compute a single byte array message for a node and its children. 064 * Using {@code taskIdToBytes}, we pack all messages for a 065 * {@code TopologySimpleNode} and its children into a single byte array. 066 * 067 * @param node the target TopologySimpleNode to generate a message for 068 * @param taskIdToBytes map containing byte array of encoded data for individual Tasks 069 * @return single byte array message 070 */ 071 private byte[] encodeScatterMsgForNode(final TopologySimpleNode node, 072 final Map<String, byte[]> taskIdToBytes) { 073 074 try (final ByteArrayOutputStream bstream = new ByteArrayOutputStream(); 075 final DataOutputStream dstream = new DataOutputStream(bstream)) { 076 077 // first write the node's encoded data 078 final String taskId = node.getTaskId(); 079 if (taskIdToBytes.containsKey(taskId)) { 080 dstream.write(taskIdToBytes.get(node.getTaskId())); 081 082 } else { 083 // in case mapOfTaskToBytes does not contain this node's id, write an empty 084 // message (zero elements) 085 dstream.writeInt(0); 086 } 087 088 // and then write its children's identifiers and their encoded data 089 for (final TopologySimpleNode child : node.getChildren()) { 090 dstream.writeUTF(child.getTaskId()); 091 final byte[] childData = encodeScatterMsgForNode(child, taskIdToBytes); 092 dstream.writeInt(childData.length); 093 dstream.write(childData); 094 } 095 096 return bstream.toByteArray(); 097 098 } catch (final IOException e) { 099 throw new RuntimeException("IOException", e); 100 } 101 } 102 103 /** 104 * Encode elements into byte arrays, and distribute them across Tasks indicated by Identifiers. 105 * Note that elements are distributed in the exact order specified in 106 * {@code elements} and not in a round-robin fashion. 107 * For example, (1, 2, 3, 4) uniformly distributed to (task1, task2, task3) would be 108 * {task1: (1, 2), task2: (3), task3: (4)}. 109 * 110 * @param elements list of data elements to encode 111 * @param counts list of numbers specifying how many elements each Task should receive 112 * @param taskOrder list of Identifiers indicating Task Ids 113 * @param codec class for encoding data 114 * @param <T> type of data 115 * @return byte representation of a map of identifiers to encoded data 116 */ 117 private <T> Map<String, byte[]> encodeAndDistributeElements(final List<T> elements, 118 final List<Integer> counts, 119 final List<? extends Identifier> taskOrder, 120 final Codec<T> codec) { 121 final Map<String, byte[]> taskIdToBytes = new HashMap<>(); 122 123 int elementsIndex = 0; 124 for (int taskOrderIndex = 0; taskOrderIndex < taskOrder.size(); taskOrderIndex++) { 125 final int elementCount = counts.get(taskOrderIndex); 126 127 try (final ByteArrayOutputStream bstream = new ByteArrayOutputStream(); 128 final DataOutputStream dstream = new DataOutputStream(bstream)) { 129 130 dstream.writeInt(elementCount); 131 for (final T element : elements.subList(elementsIndex, elementsIndex + elementCount)) { 132 final byte[] encodedElement = codec.encode(element); 133 dstream.writeInt(encodedElement.length); 134 dstream.write(encodedElement); 135 } 136 taskIdToBytes.put(taskOrder.get(taskOrderIndex).toString(), bstream.toByteArray()); 137 138 } catch (final IOException e) { 139 throw new RuntimeException("IOException", e); 140 } 141 142 elementsIndex += elementCount; 143 } 144 145 return taskIdToBytes; 146 } 147}