This project has retired. For details please refer to its Attic page.
Source code
001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package org.apache.reef.io.network.group.impl.operators;
020
021import org.apache.reef.driver.task.TaskConfigurationOptions;
022import org.apache.reef.exception.evaluator.NetworkException;
023import org.apache.reef.io.network.exception.ParentDeadException;
024import org.apache.reef.io.network.group.api.operators.Gather;
025import org.apache.reef.io.network.group.api.task.CommGroupNetworkHandler;
026import org.apache.reef.io.network.group.api.task.CommunicationGroupServiceClient;
027import org.apache.reef.io.network.group.api.task.OperatorTopology;
028import org.apache.reef.io.network.group.impl.GroupCommunicationMessage;
029import org.apache.reef.io.network.group.impl.config.parameters.*;
030import org.apache.reef.io.network.group.impl.task.OperatorTopologyImpl;
031import org.apache.reef.io.network.group.impl.utils.Utils;
032import org.apache.reef.io.network.impl.NetworkService;
033import org.apache.reef.io.serialization.Codec;
034import org.apache.reef.tang.annotations.Name;
035import org.apache.reef.tang.annotations.Parameter;
036import org.apache.reef.wake.EventHandler;
037import org.apache.reef.wake.Identifier;
038
039import javax.inject.Inject;
040import java.io.ByteArrayInputStream;
041import java.io.DataInputStream;
042import java.io.IOException;
043import java.util.*;
044import java.util.concurrent.atomic.AtomicBoolean;
045import java.util.logging.Level;
046import java.util.logging.Logger;
047
048public class GatherReceiver<T> implements Gather.Receiver<T>, EventHandler<GroupCommunicationMessage> {
049
050  private static final Logger LOG = Logger.getLogger(GatherReceiver.class.getName());
051
052  private final Class<? extends Name<String>> groupName;
053  private final Class<? extends Name<String>> operName;
054  private final Codec<T> dataCodec;
055  private final OperatorTopology topology;
056  private final CommunicationGroupServiceClient commGroupClient;
057  private final AtomicBoolean init = new AtomicBoolean(false);
058  private final int version;
059
060  @Inject
061  public GatherReceiver(@Parameter(CommunicationGroupName.class) final String groupName,
062                        @Parameter(OperatorName.class) final String operName,
063                        @Parameter(TaskConfigurationOptions.Identifier.class) final String selfId,
064                        @Parameter(DataCodec.class) final Codec<T> dataCodec,
065                        @Parameter(DriverIdentifierGroupComm.class) final String driverId,
066                        @Parameter(TaskVersion.class) final int version,
067                        final CommGroupNetworkHandler commGroupNetworkHandler,
068                        final NetworkService<GroupCommunicationMessage> netService,
069                        final CommunicationGroupServiceClient commGroupClient) {
070    LOG.finest(operName + " has CommGroupHandler-" + commGroupNetworkHandler.toString());
071    this.version = version;
072    this.groupName = Utils.getClass(groupName);
073    this.operName = Utils.getClass(operName);
074    this.dataCodec = dataCodec;
075    this.topology = new OperatorTopologyImpl(this.groupName, this.operName,
076                                             selfId, driverId, new Sender(netService), version);
077    this.commGroupClient = commGroupClient;
078    commGroupNetworkHandler.register(this.operName, this);
079  }
080
081  @Override
082  public int getVersion() {
083    return version;
084  }
085
086  @Override
087  public void initialize() throws ParentDeadException {
088    topology.initialize();
089  }
090
091  @Override
092  public Class<? extends Name<String>> getOperName() {
093    return operName;
094  }
095
096  @Override
097  public Class<? extends Name<String>> getGroupName() {
098    return groupName;
099  }
100
101  @Override
102  public String toString() {
103    final StringBuilder sb = new StringBuilder("GatherReceiver:")
104        .append(Utils.simpleName(groupName))
105        .append(":")
106        .append(Utils.simpleName(operName))
107        .append(":")
108        .append(version);
109    return sb.toString();
110  }
111
112  @Override
113  public void onNext(final GroupCommunicationMessage msg) {
114    topology.handle(msg);
115  }
116
117  @Override
118  public List<T> receive() throws NetworkException, InterruptedException {
119    LOG.entering("GatherReceiver", "receive");
120    final Map<String, T> mapOfTaskIdToData = receiveMapOfTaskIdToData();
121
122    LOG.log(Level.FINE, "{0} Sorting data according to lexicographical order of task identifiers.", this);
123    final TreeMap<String, T> sortedMapOfTaskIdToData = new TreeMap<>(mapOfTaskIdToData);
124    final List<T> retList = new LinkedList<>(sortedMapOfTaskIdToData.values());
125
126    LOG.exiting("GatherReceiver", "receive");
127    return retList;
128  }
129
130  @Override
131  public List<T> receive(final List<? extends Identifier> order) throws NetworkException, InterruptedException {
132    LOG.entering("GatherReceiver", "receive");
133    final Map<String, T> mapOfTaskIdToData = receiveMapOfTaskIdToData();
134
135    LOG.log(Level.FINE, "{0} Sorting data according to specified order of task identifiers.", this);
136    final List<T> retList = new LinkedList<>();
137    for (final Identifier key : order) {
138      final String keyString = key.toString();
139      if (mapOfTaskIdToData.containsKey(keyString)) {
140        retList.add(mapOfTaskIdToData.get(key.toString()));
141      } else {
142        LOG.warning(this + " Received no data from " + keyString + ". Adding null.");
143        retList.add(null);
144      }
145    }
146
147    LOG.exiting("GatherReceiver", "receive");
148    return retList;
149  }
150
151  private Map<String, T> receiveMapOfTaskIdToData() {
152    LOG.entering("GatherReceiver", "receiveMapOfTaskIdToData");
153    // I am root.
154    LOG.fine("I am " + this);
155
156    if (init.compareAndSet(false, true)) {
157      LOG.fine(this + " Communication group initializing.");
158      commGroupClient.initialize();
159      LOG.fine(this + " Communication group initialized.");
160    }
161
162    final Map<String, T> mapOfTaskIdToData = new HashMap<>();
163    try {
164      LOG.fine(this + " Waiting for children.");
165      final byte[] gatheredDataFromChildren = topology.recvFromChildren();
166
167      LOG.fine("Using " + dataCodec.getClass().getSimpleName() + " as codec.");
168      try (final ByteArrayInputStream bstream = new ByteArrayInputStream(gatheredDataFromChildren);
169           final DataInputStream dstream = new DataInputStream(bstream)) {
170        while (dstream.available() > 0) {
171          final String identifier = dstream.readUTF();
172          final int dataLength = dstream.readInt();
173          final byte[] data = new byte[dataLength];
174          dstream.readFully(data);
175          mapOfTaskIdToData.put(identifier, dataCodec.decode(data));
176        }
177        LOG.fine(this + " Successfully received gathered data.");
178      }
179
180    } catch (final ParentDeadException e) {
181      throw new RuntimeException("ParentDeadException", e);
182    } catch (final IOException e) {
183      throw new RuntimeException("IOException", e);
184    }
185
186    LOG.exiting("GatherReceiver", "receiveMapOfTaskIdToData");
187    return mapOfTaskIdToData;
188  }
189}