This project has retired. For details please refer to its Attic page.
Source code
001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package org.apache.reef.examples.group.bgd;
020
021import org.apache.reef.examples.group.bgd.data.Example;
022import org.apache.reef.examples.group.bgd.loss.LossFunction;
023import org.apache.reef.examples.group.bgd.operatornames.*;
024import org.apache.reef.examples.group.bgd.parameters.AllCommunicationGroup;
025import org.apache.reef.examples.group.bgd.parameters.ProbabilityOfFailure;
026import org.apache.reef.examples.group.bgd.utils.StepSizes;
027import org.apache.reef.examples.group.utils.math.DenseVector;
028import org.apache.reef.examples.group.utils.math.Vector;
029import org.apache.reef.io.network.group.api.operators.Broadcast;
030import org.apache.reef.io.network.group.api.operators.Reduce;
031import org.apache.reef.io.network.group.api.task.CommunicationGroupClient;
032import org.apache.reef.io.network.group.api.task.GroupCommClient;
033import org.apache.reef.io.network.util.Pair;
034import org.apache.reef.tang.annotations.Parameter;
035import org.apache.reef.task.Task;
036
037import javax.inject.Inject;
038import java.util.List;
039import java.util.logging.Logger;
040
041/**
042 * Slave task for BGD example.
043 */
044public class SlaveTask implements Task {
045
046  private static final Logger LOG = Logger.getLogger(SlaveTask.class.getName());
047
048  private final double failureProb;
049
050  private final CommunicationGroupClient communicationGroup;
051  private final Broadcast.Receiver<ControlMessages> controlMessageBroadcaster;
052  private final Broadcast.Receiver<Vector> modelBroadcaster;
053  private final Reduce.Sender<Pair<Pair<Double, Integer>, Vector>> lossAndGradientReducer;
054  private final Broadcast.Receiver<Pair<Vector, Vector>> modelAndDescentDirectionBroadcaster;
055  private final Broadcast.Receiver<Vector> descentDirectionBroadcaster;
056  private final Reduce.Sender<Pair<Vector, Integer>> lineSearchEvaluationsReducer;
057  private final Broadcast.Receiver<Double> minEtaBroadcaster;
058  private List<Example> examples = null;
059  private final ExampleList dataSet;
060  private final LossFunction lossFunction;
061  private final StepSizes ts;
062
063  private Vector model = null;
064  private Vector descentDirection = null;
065
066  @Inject
067  public SlaveTask(
068      final GroupCommClient groupCommClient,
069      final ExampleList dataSet,
070      final LossFunction lossFunction,
071      @Parameter(ProbabilityOfFailure.class) final double pFailure,
072      final StepSizes ts) {
073
074    this.dataSet = dataSet;
075    this.lossFunction = lossFunction;
076    this.failureProb = pFailure;
077    LOG.info("Using pFailure=" + this.failureProb);
078    this.ts = ts;
079
080    this.communicationGroup = groupCommClient.getCommunicationGroup(AllCommunicationGroup.class);
081    this.controlMessageBroadcaster = communicationGroup.getBroadcastReceiver(ControlMessageBroadcaster.class);
082    this.modelBroadcaster = communicationGroup.getBroadcastReceiver(ModelBroadcaster.class);
083    this.lossAndGradientReducer = communicationGroup.getReduceSender(LossAndGradientReducer.class);
084    this.modelAndDescentDirectionBroadcaster =
085        communicationGroup.getBroadcastReceiver(ModelAndDescentDirectionBroadcaster.class);
086    this.descentDirectionBroadcaster = communicationGroup.getBroadcastReceiver(DescentDirectionBroadcaster.class);
087    this.lineSearchEvaluationsReducer = communicationGroup.getReduceSender(LineSearchEvaluationsReducer.class);
088    this.minEtaBroadcaster = communicationGroup.getBroadcastReceiver(MinEtaBroadcaster.class);
089  }
090
091  @Override
092  public byte[] call(final byte[] memento) throws Exception {
093    /*
094     * In the case where there will be evaluator failure and data is not in
095     * memory we want to load the data while waiting to join the communication
096     * group
097     */
098    loadData();
099
100    for (boolean repeat = true; repeat;) {
101
102      final ControlMessages controlMessage = controlMessageBroadcaster.receive();
103      switch (controlMessage) {
104
105      case Stop:
106        repeat = false;
107        break;
108
109      case ComputeGradientWithModel:
110        failPerhaps();
111        this.model = modelBroadcaster.receive();
112        lossAndGradientReducer.send(computeLossAndGradient());
113        break;
114
115      case ComputeGradientWithMinEta:
116        failPerhaps();
117        final double minEta = minEtaBroadcaster.receive();
118        assert descentDirection != null;
119        this.descentDirection.scale(minEta);
120        assert model != null;
121        this.model.add(descentDirection);
122        lossAndGradientReducer.send(computeLossAndGradient());
123        break;
124
125      case DoLineSearch:
126        failPerhaps();
127        this.descentDirection = descentDirectionBroadcaster.receive();
128        lineSearchEvaluationsReducer.send(lineSearchEvals());
129        break;
130
131      case DoLineSearchWithModel:
132        failPerhaps();
133        final Pair<Vector, Vector> modelAndDescentDir = modelAndDescentDirectionBroadcaster.receive();
134        this.model = modelAndDescentDir.getFirst();
135        this.descentDirection = modelAndDescentDir.getSecond();
136        lineSearchEvaluationsReducer.send(lineSearchEvals());
137        break;
138
139      default:
140        break;
141      }
142    }
143
144    return null;
145  }
146
147  private void failPerhaps() {
148    if (Math.random() < failureProb) {
149      throw new RuntimeException("Simulated Failure");
150    }
151  }
152
153  private Pair<Vector, Integer> lineSearchEvals() {
154
155    if (examples == null) {
156      loadData();
157    }
158
159    final Vector zed = new DenseVector(examples.size());
160    final Vector ee = new DenseVector(examples.size());
161
162    for (int i = 0; i < examples.size(); i++) {
163      final Example example = examples.get(i);
164      double f = example.predict(model);
165      zed.set(i, f);
166      f = example.predict(descentDirection);
167      ee.set(i, f);
168    }
169
170    final double[] t = ts.getT();
171    final Vector evaluations = new DenseVector(t.length);
172    int i = 0;
173    for (final double d : t) {
174      double loss = 0;
175      for (int j = 0; j < examples.size(); j++) {
176        final Example example = examples.get(j);
177        final double val = zed.get(j) + d * ee.get(j);
178        loss += this.lossFunction.computeLoss(example.getLabel(), val);
179      }
180      evaluations.set(i++, loss);
181    }
182
183    return new Pair<>(evaluations, examples.size());
184  }
185
186  private Pair<Pair<Double, Integer>, Vector> computeLossAndGradient() {
187
188    if (examples == null) {
189      loadData();
190    }
191
192    final Vector gradient = new DenseVector(model.size());
193    double loss = 0.0;
194    for (final Example example : examples) {
195      final double f = example.predict(model);
196      final double g = this.lossFunction.computeGradient(example.getLabel(), f);
197      example.addGradient(gradient, g);
198      loss += this.lossFunction.computeLoss(example.getLabel(), f);
199    }
200
201    return new Pair<>(new Pair<>(loss, examples.size()), gradient);
202  }
203
204  private void loadData() {
205    LOG.info("Loading data");
206    examples = dataSet.getExamples();
207  }
208}