/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
019package org.apache.reef.io.network.group.impl.task;
020
021import org.apache.reef.driver.task.TaskConfigurationOptions;
022import org.apache.reef.exception.evaluator.NetworkException;
023import org.apache.reef.io.network.group.api.operators.*;
024import org.apache.reef.io.network.group.impl.config.parameters.DriverIdentifierGroupComm;
025import org.apache.reef.io.network.group.impl.driver.TopologySimpleNode;
026import org.apache.reef.io.network.group.impl.driver.TopologySerializer;
027import org.apache.reef.io.network.impl.NetworkService;
028import org.apache.reef.io.network.group.api.GroupChanges;
029import org.apache.reef.io.network.group.api.task.CommGroupNetworkHandler;
030import org.apache.reef.io.network.group.api.task.CommunicationGroupServiceClient;
031import org.apache.reef.io.network.group.api.task.GroupCommNetworkHandler;
032import org.apache.reef.io.network.group.impl.GroupChangesCodec;
033import org.apache.reef.io.network.group.impl.GroupChangesImpl;
034import org.apache.reef.io.network.group.impl.GroupCommunicationMessage;
035import org.apache.reef.io.network.group.impl.config.parameters.CommunicationGroupName;
036import org.apache.reef.io.network.group.impl.config.parameters.OperatorName;
037import org.apache.reef.io.network.group.impl.config.parameters.SerializedOperConfigs;
038import org.apache.reef.io.network.group.impl.operators.Sender;
039import org.apache.reef.io.network.group.impl.utils.Utils;
040import org.apache.reef.io.network.proto.ReefNetworkGroupCommProtos;
041import org.apache.reef.io.network.util.Pair;
042import org.apache.reef.io.serialization.Codec;
043import org.apache.reef.tang.Configuration;
044import org.apache.reef.tang.Injector;
045import org.apache.reef.tang.annotations.Name;
046import org.apache.reef.tang.annotations.Parameter;
047import org.apache.reef.tang.exceptions.InjectionException;
048import org.apache.reef.tang.formats.ConfigurationSerializer;
049import org.apache.reef.wake.EStage;
050import org.apache.reef.wake.Identifier;
051import org.apache.reef.wake.IdentifierFactory;
052import org.apache.reef.wake.impl.ThreadPoolStage;
053
054import javax.inject.Inject;
055import java.io.IOException;
056import java.util.*;
057import java.util.concurrent.CountDownLatch;
058import java.util.concurrent.atomic.AtomicBoolean;
059import java.util.logging.Logger;
060
061public final class CommunicationGroupClientImpl implements CommunicationGroupServiceClient {
062  private static final Logger LOG = Logger.getLogger(CommunicationGroupClientImpl.class.getName());
063
064  private final GroupCommNetworkHandler groupCommNetworkHandler;
065  private final Class<? extends Name<String>> groupName;
066  private final Map<Class<? extends Name<String>>, GroupCommOperator> operators;
067  private final Sender sender;
068
069  private final String taskId;
070  private final boolean isScatterSender;
071  private final IdentifierFactory identifierFactory;
072  private List<Identifier> activeSlaveTasks;
073  private TopologySimpleNode topologySimpleNodeRoot;
074
075  private final String driverId;
076
077  private final CommGroupNetworkHandler commGroupNetworkHandler;
078
079  private final AtomicBoolean init = new AtomicBoolean(false);
080
081  @Inject
082  private CommunicationGroupClientImpl(@Parameter(CommunicationGroupName.class) final String groupName,
083                                      @Parameter(TaskConfigurationOptions.Identifier.class) final String taskId,
084                                      @Parameter(DriverIdentifierGroupComm.class) final String driverId,
085                                      final GroupCommNetworkHandler groupCommNetworkHandler,
086                                      @Parameter(SerializedOperConfigs.class) final Set<String> operatorConfigs,
087                                      final ConfigurationSerializer configSerializer,
088                                      final NetworkService<GroupCommunicationMessage> netService,
089                                      final CommGroupNetworkHandler commGroupNetworkHandler,
090                                      final Injector injector) {
091    this.taskId = taskId;
092    this.driverId = driverId;
093    LOG.finest(groupName + " has GroupCommHandler-" + groupCommNetworkHandler.toString());
094    this.identifierFactory = netService.getIdentifierFactory();
095    this.groupName = Utils.getClass(groupName);
096    this.groupCommNetworkHandler = groupCommNetworkHandler;
097    this.commGroupNetworkHandler = commGroupNetworkHandler;
098    this.sender = new Sender(netService);
099    this.operators = new TreeMap<>(new Comparator<Class<? extends Name<String>>>() {
100
101      @Override
102      public int compare(final Class<? extends Name<String>> o1, final Class<? extends Name<String>> o2) {
103        final String s1 = o1.getSimpleName();
104        final String s2 = o2.getSimpleName();
105        return s1.compareTo(s2);
106      }
107    });
108    try {
109      this.groupCommNetworkHandler.register(this.groupName, commGroupNetworkHandler);
110
111      boolean operatorIsScatterSender = false;
112      for (final String operatorConfigStr : operatorConfigs) {
113
114        final Configuration operatorConfig = configSerializer.fromString(operatorConfigStr);
115        final Injector forkedInjector = injector.forkInjector(operatorConfig);
116
117        forkedInjector.bindVolatileInstance(CommunicationGroupServiceClient.class, this);
118
119        final GroupCommOperator operator = forkedInjector.getInstance(GroupCommOperator.class);
120        final String operName = forkedInjector.getNamedInstance(OperatorName.class);
121        this.operators.put(Utils.getClass(operName), operator);
122        LOG.finest(operName + " has CommGroupHandler-" + commGroupNetworkHandler.toString());
123
124        if (!operatorIsScatterSender && operator instanceof Scatter.Sender) {
125          LOG.fine(operName + " is a scatter sender. Will keep track of active slave tasks.");
126          operatorIsScatterSender = true;
127        }
128      }
129      this.isScatterSender = operatorIsScatterSender;
130    } catch (final InjectionException | IOException e) {
131      throw new RuntimeException("Unable to deserialize operator config", e);
132    }
133  }
134
135  @Override
136  public Broadcast.Sender getBroadcastSender(final Class<? extends Name<String>> operatorName) {
137    LOG.entering("CommunicationGroupClientImpl", "getBroadcastSender", new Object[]{getQualifiedName(),
138        Utils.simpleName(operatorName)});
139    final GroupCommOperator op = operators.get(operatorName);
140    if (!(op instanceof Broadcast.Sender)) {
141      throw new RuntimeException("Configured operator is not a broadcast sender");
142    }
143    commGroupNetworkHandler.addTopologyElement(operatorName);
144    LOG.exiting("CommunicationGroupClientImpl", "getBroadcastSender", getQualifiedName() + op);
145    return (Broadcast.Sender) op;
146  }
147
148  @Override
149  public Reduce.Receiver getReduceReceiver(final Class<? extends Name<String>> operatorName) {
150    LOG.entering("CommunicationGroupClientImpl", "getReduceReceiver", new Object[]{getQualifiedName(),
151        Utils.simpleName(operatorName)});
152    final GroupCommOperator op = operators.get(operatorName);
153    if (!(op instanceof Reduce.Receiver)) {
154      throw new RuntimeException("Configured operator is not a reduce receiver");
155    }
156    commGroupNetworkHandler.addTopologyElement(operatorName);
157    LOG.exiting("CommunicationGroupClientImpl", "getReduceReceiver", getQualifiedName() + op);
158    return (Reduce.Receiver) op;
159  }
160
161  @Override
162  public Scatter.Sender getScatterSender(final Class<? extends Name<String>> operatorName) {
163    LOG.entering("CommunicationGroupClientImpl", "getScatterSender", new Object[]{getQualifiedName(),
164        Utils.simpleName(operatorName)});
165    final GroupCommOperator op = operators.get(operatorName);
166    if (!(op instanceof Scatter.Sender)) {
167      throw new RuntimeException("Configured operator is not a scatter sender");
168    }
169    commGroupNetworkHandler.addTopologyElement(operatorName);
170    LOG.exiting("CommunicationGroupClientImpl", "getScatterSender", getQualifiedName() + op);
171    return (Scatter.Sender) op;
172  }
173
174  @Override
175  public Gather.Receiver getGatherReceiver(final Class<? extends Name<String>> operatorName) {
176    LOG.entering("CommunicationGroupClientImpl", "getGatherReceiver", new Object[]{getQualifiedName(),
177        Utils.simpleName(operatorName)});
178    final GroupCommOperator op = operators.get(operatorName);
179    if (!(op instanceof Gather.Receiver)) {
180      throw new RuntimeException("Configured operator is not a gather receiver");
181    }
182    commGroupNetworkHandler.addTopologyElement(operatorName);
183    LOG.exiting("CommunicationGroupClientImpl", "getGatherReceiver", getQualifiedName() + op);
184    return (Gather.Receiver) op;
185  }
186
187
188  @Override
189  public Broadcast.Receiver getBroadcastReceiver(final Class<? extends Name<String>> operatorName) {
190    LOG.entering("CommunicationGroupClientImpl", "getBroadcastReceiver", new Object[]{getQualifiedName(),
191        Utils.simpleName(operatorName)});
192    final GroupCommOperator op = operators.get(operatorName);
193    if (!(op instanceof Broadcast.Receiver)) {
194      throw new RuntimeException("Configured operator is not a broadcast receiver");
195    }
196    commGroupNetworkHandler.addTopologyElement(operatorName);
197    LOG.exiting("CommunicationGroupClientImpl", "getBroadcastReceiver", getQualifiedName() + op);
198    return (Broadcast.Receiver) op;
199  }
200
201  @Override
202  public Reduce.Sender getReduceSender(final Class<? extends Name<String>> operatorName) {
203    LOG.entering("CommunicationGroupClientImpl", "getReduceSender", new Object[]{getQualifiedName(),
204        Utils.simpleName(operatorName)});
205    final GroupCommOperator op = operators.get(operatorName);
206    if (!(op instanceof Reduce.Sender)) {
207      throw new RuntimeException("Configured operator is not a reduce sender");
208    }
209    commGroupNetworkHandler.addTopologyElement(operatorName);
210    LOG.exiting("CommunicationGroupClientImpl", "getReduceSender", getQualifiedName() + op);
211    return (Reduce.Sender) op;
212  }
213
214  @Override
215  public Scatter.Receiver getScatterReceiver(final Class<? extends Name<String>> operatorName) {
216    LOG.entering("CommunicationGroupClientImpl", "getScatterReceiver", new Object[]{getQualifiedName(),
217        Utils.simpleName(operatorName)});
218    final GroupCommOperator op = operators.get(operatorName);
219    if (!(op instanceof Scatter.Receiver)) {
220      throw new RuntimeException("Configured operator is not a scatter receiver");
221    }
222    commGroupNetworkHandler.addTopologyElement(operatorName);
223    LOG.exiting("CommunicationGroupClientImpl", "getScatterReceiver", getQualifiedName() + op);
224    return (Scatter.Receiver) op;
225  }
226
227  @Override
228  public Gather.Sender getGatherSender(final Class<? extends Name<String>> operatorName) {
229    LOG.entering("CommunicationGroupClientImpl", "getGatherSender", new Object[]{getQualifiedName(),
230        Utils.simpleName(operatorName)});
231    final GroupCommOperator op = operators.get(operatorName);
232    if (!(op instanceof Gather.Sender)) {
233      throw new RuntimeException("Configured operator is not a gather sender");
234    }
235    commGroupNetworkHandler.addTopologyElement(operatorName);
236    LOG.exiting("CommunicationGroupClientImpl", "getGatherSender", getQualifiedName() + op);
237    return (Gather.Sender) op;
238  }
239
240  @Override
241  public void initialize() {
242    LOG.entering("CommunicationGroupClientImpl", "initialize", getQualifiedName());
243    if (init.compareAndSet(false, true)) {
244      LOG.finest("CommGroup-" + groupName + " is initializing");
245      final CountDownLatch initLatch = new CountDownLatch(operators.size());
246
247      final InitHandler initHandler = new InitHandler(initLatch);
248      final EStage<GroupCommOperator> initStage = new ThreadPoolStage<>(initHandler, operators.size());
249      for (final GroupCommOperator op : operators.values()) {
250        initStage.onNext(op);
251      }
252
253      try {
254        initLatch.await();
255      } catch (final InterruptedException e) {
256        throw new RuntimeException("InterruptedException while waiting for initialization", e);
257      }
258
259      if (isScatterSender) {
260        updateTopology();
261      }
262
263      if (initHandler.getException() != null) {
264        throw new RuntimeException(getQualifiedName() + "Parent dead. Current behavior is for the child to die too.");
265      }
266    }
267    LOG.exiting("CommunicationGroupClientImpl", "initialize", getQualifiedName());
268  }
269
270  @Override
271  public GroupChanges getTopologyChanges() {
272    LOG.entering("CommunicationGroupClientImpl", "getTopologyChanges", getQualifiedName());
273    for (final GroupCommOperator op : operators.values()) {
274      final Class<? extends Name<String>> operName = op.getOperName();
275      LOG.finest("Sending TopologyChanges msg to driver");
276      try {
277        sender.send(Utils.bldVersionedGCM(groupName, operName,
278            ReefNetworkGroupCommProtos.GroupCommMessage.Type.TopologyChanges, taskId, op.getVersion(), driverId,
279            0, Utils.EMPTY_BYTE_ARR));
280      } catch (final NetworkException e) {
281        throw new RuntimeException("NetworkException while sending GetTopologyChanges", e);
282      }
283    }
284    final Codec<GroupChanges> changesCodec = new GroupChangesCodec();
285    final Map<Class<? extends Name<String>>, GroupChanges> perOpChanges = new HashMap<>();
286    for (final GroupCommOperator op : operators.values()) {
287      final Class<? extends Name<String>> operName = op.getOperName();
288      final byte[] changes = commGroupNetworkHandler.waitForTopologyChanges(operName);
289      perOpChanges.put(operName, changesCodec.decode(changes));
290    }
291    final GroupChanges retVal = mergeGroupChanges(perOpChanges);
292    LOG.exiting("CommunicationGroupClientImpl", "getTopologyChanges", getQualifiedName() + retVal);
293    return retVal;
294  }
295
296  /**
297   * @param perOpChanges
298   * @return
299   */
300  private GroupChanges mergeGroupChanges(final Map<Class<? extends Name<String>>, GroupChanges> perOpChanges) {
301    LOG.entering("CommunicationGroupClientImpl", "mergeGroupChanges", new Object[]{getQualifiedName(), perOpChanges});
302    boolean doChangesExist = false;
303    for (final GroupChanges change : perOpChanges.values()) {
304      if (change.exist()) {
305        doChangesExist = true;
306        break;
307      }
308    }
309    final GroupChanges changes = new GroupChangesImpl(doChangesExist);
310    LOG.exiting("CommunicationGroupClientImpl", "mergeGroupChanges", getQualifiedName() + changes);
311    return changes;
312  }
313
314  @Override
315  public void updateTopology() {
316    LOG.entering("CommunicationGroupClientImpl", "updateTopology", getQualifiedName());
317    for (final GroupCommOperator op : operators.values()) {
318      final Class<? extends Name<String>> operName = op.getOperName();
319      try {
320        sender.send(Utils.bldVersionedGCM(groupName, operName,
321            ReefNetworkGroupCommProtos.GroupCommMessage.Type.UpdateTopology, taskId, op.getVersion(), driverId,
322            0, Utils.EMPTY_BYTE_ARR));
323      } catch (final NetworkException e) {
324        throw new RuntimeException("NetworkException while sending UpdateTopology", e);
325      }
326    }
327    for (final GroupCommOperator op : operators.values()) {
328      final Class<? extends Name<String>> operName = op.getOperName();
329      GroupCommunicationMessage msg;
330      do {
331        msg = commGroupNetworkHandler.waitForTopologyUpdate(operName);
332      } while (!isMsgVersionOk(msg));
333
334      if (isScatterSender) {
335        updateActiveTasks(msg);
336      }
337    }
338    LOG.exiting("CommunicationGroupClientImpl", "updateTopology", getQualifiedName());
339  }
340
341  private void updateActiveTasks(final GroupCommunicationMessage msg) {
342    LOG.entering("CommunicationGroupClientImpl", "updateActiveTasks", new Object[]{getQualifiedName(), msg});
343
344    final Pair<TopologySimpleNode, List<Identifier>> pair =
345        TopologySerializer.decode(msg.getData()[0], identifierFactory);
346
347    topologySimpleNodeRoot = pair.getFirst();
348
349    activeSlaveTasks = pair.getSecond();
350    // remove myself
351    activeSlaveTasks.remove(identifierFactory.getNewInstance(taskId));
352    // sort the tasks in lexicographical order on task ids
353    Collections.sort(activeSlaveTasks, new Comparator<Identifier>() {
354      @Override
355      public int compare(final Identifier o1, final Identifier o2) {
356        return o1.toString().compareTo(o2.toString());
357      }
358    });
359
360    LOG.exiting("CommunicationGroupClientImpl", "updateActiveTasks", new Object[]{getQualifiedName(), msg});
361  }
362
363  private boolean isMsgVersionOk(final GroupCommunicationMessage msg) {
364    LOG.entering("CommunicationGroupClientImpl", "isMsgVersionOk", new Object[]{getQualifiedName(), msg});
365    if (msg.hasVersion()) {
366      final int msgVersion = msg.getVersion();
367      final GroupCommOperator operator = operators.get(Utils.getClass(msg.getOperatorname()));
368      final int nodeVersion = operator.getVersion();
369      final boolean retVal;
370      if (msgVersion < nodeVersion) {
371        LOG.warning(getQualifiedName() + "Received a ver-" + msgVersion + " msg while expecting ver-" + nodeVersion
372            + ". Discarding msg");
373        retVal = false;
374      } else {
375        retVal = true;
376      }
377      LOG.exiting("CommunicationGroupClientImpl", "isMsgVersionOk",
378          Arrays.toString(new Object[]{retVal, getQualifiedName(), msg}));
379      return retVal;
380    } else {
381      throw new RuntimeException(getQualifiedName() + "can only deal with versioned msgs");
382    }
383  }
384
385  @Override
386  public List<Identifier> getActiveSlaveTasks() {
387    return this.activeSlaveTasks;
388  }
389
390  @Override
391  public TopologySimpleNode getTopologySimpleNodeRoot() {
392    return this.topologySimpleNodeRoot;
393  }
394
395  private String getQualifiedName() {
396    return Utils.simpleName(groupName) + " ";
397  }
398
399  @Override
400  public Class<? extends Name<String>> getName() {
401    return groupName;
402  }
403
404}