This project has retired. For details please refer to its Attic page.
Source code
001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package org.apache.reef.examples.group.bgd;
020
021import org.apache.reef.examples.group.bgd.operatornames.*;
022import org.apache.reef.examples.group.bgd.parameters.*;
023import org.apache.reef.examples.group.bgd.utils.StepSizes;
024import org.apache.reef.examples.group.utils.math.DenseVector;
025import org.apache.reef.examples.group.utils.math.Vector;
026import org.apache.reef.examples.group.utils.timer.Timer;
027import org.apache.reef.exception.evaluator.NetworkException;
028import org.apache.reef.io.Tuple;
029import org.apache.reef.io.network.group.api.operators.Broadcast;
030import org.apache.reef.io.network.group.api.operators.Reduce;
031import org.apache.reef.io.network.group.api.GroupChanges;
032import org.apache.reef.io.network.group.api.task.CommunicationGroupClient;
033import org.apache.reef.io.network.group.api.task.GroupCommClient;
034import org.apache.reef.io.network.util.Pair;
035import org.apache.reef.io.serialization.Codec;
036import org.apache.reef.io.serialization.SerializableCodec;
037import org.apache.reef.tang.annotations.Parameter;
038import org.apache.reef.task.Task;
039
040import javax.inject.Inject;
041import java.util.ArrayList;
042import java.util.logging.Level;
043import java.util.logging.Logger;
044
045/**
046 * Master task for BGD example.
047 */
048public class MasterTask implements Task {
049
050  public static final String TASK_ID = "MasterTask";
051
052  private static final Logger LOG = Logger.getLogger(MasterTask.class.getName());
053
054  private final CommunicationGroupClient communicationGroupClient;
055  private final Broadcast.Sender<ControlMessages> controlMessageBroadcaster;
056  private final Broadcast.Sender<Vector> modelBroadcaster;
057  private final Reduce.Receiver<Pair<Pair<Double, Integer>, Vector>> lossAndGradientReducer;
058  private final Broadcast.Sender<Pair<Vector, Vector>> modelAndDescentDirectionBroadcaster;
059  private final Broadcast.Sender<Vector> descentDriectionBroadcaster;
060  private final Reduce.Receiver<Pair<Vector, Integer>> lineSearchEvaluationsReducer;
061  private final Broadcast.Sender<Double> minEtaBroadcaster;
062  private final boolean ignoreAndContinue;
063  private final StepSizes ts;
064  private final double lambda;
065  private final int maxIters;
066  private final ArrayList<Double> losses = new ArrayList<>();
067  private final Codec<ArrayList<Double>> lossCodec = new SerializableCodec<>();
068  private final Vector model;
069
070  private boolean sendModel = true;
071  private double minEta = 0;
072
073  @Inject
074  public MasterTask(
075      final GroupCommClient groupCommClient,
076      @Parameter(ModelDimensions.class) final int dimensions,
077      @Parameter(Lambda.class) final double lambda,
078      @Parameter(Iterations.class) final int maxIters,
079      @Parameter(EnableRampup.class) final boolean rampup,
080      final StepSizes ts) {
081
082    this.lambda = lambda;
083    this.maxIters = maxIters;
084    this.ts = ts;
085    this.ignoreAndContinue = rampup;
086    this.model = new DenseVector(dimensions);
087    this.communicationGroupClient = groupCommClient.getCommunicationGroup(AllCommunicationGroup.class);
088    this.controlMessageBroadcaster = communicationGroupClient.getBroadcastSender(ControlMessageBroadcaster.class);
089    this.modelBroadcaster = communicationGroupClient.getBroadcastSender(ModelBroadcaster.class);
090    this.lossAndGradientReducer = communicationGroupClient.getReduceReceiver(LossAndGradientReducer.class);
091    this.modelAndDescentDirectionBroadcaster =
092        communicationGroupClient.getBroadcastSender(ModelAndDescentDirectionBroadcaster.class);
093    this.descentDriectionBroadcaster = communicationGroupClient.getBroadcastSender(DescentDirectionBroadcaster.class);
094    this.lineSearchEvaluationsReducer = communicationGroupClient.getReduceReceiver(LineSearchEvaluationsReducer.class);
095    this.minEtaBroadcaster = communicationGroupClient.getBroadcastSender(MinEtaBroadcaster.class);
096  }
097
098  @Override
099  public byte[] call(final byte[] memento) throws Exception {
100
101    double gradientNorm = Double.MAX_VALUE;
102    for (int iteration = 1; !converged(iteration, gradientNorm); ++iteration) {
103      try (final Timer t = new Timer("Current Iteration(" + iteration + ")")) {
104        final Pair<Double, Vector> lossAndGradient = computeLossAndGradient();
105        losses.add(lossAndGradient.getFirst());
106        final Vector descentDirection = getDescentDirection(lossAndGradient.getSecond());
107
108        updateModel(descentDirection);
109
110        gradientNorm = descentDirection.norm2();
111      }
112    }
113    LOG.log(Level.INFO, "OUT: Stop");
114    controlMessageBroadcaster.send(ControlMessages.Stop);
115
116    for (final Double loss : losses) {
117      LOG.log(Level.INFO, "OUT: LOSS = {0}", loss);
118    }
119    return lossCodec.encode(losses);
120  }
121
122  private void updateModel(final Vector descentDirection) throws NetworkException, InterruptedException {
123    try (final Timer t = new Timer("GetDescentDirection + FindMinEta + UpdateModel")) {
124      final Vector lineSearchEvals = lineSearch(descentDirection);
125      minEta = findMinEta(model, descentDirection, lineSearchEvals);
126      model.multAdd(minEta, descentDirection);
127    }
128
129    LOG.log(Level.INFO, "OUT: New Model = {0}", model);
130  }
131
132  private Vector lineSearch(final Vector descentDirection) throws NetworkException, InterruptedException {
133    Vector lineSearchResults = null;
134    boolean allDead = false;
135    do {
136      try (final Timer t = new Timer("LineSearch - Broadcast("
137          + (sendModel ? "ModelAndDescentDirection" : "DescentDirection") + ") + Reduce(LossEvalsInLineSearch)")) {
138        if (sendModel) {
139          LOG.log(Level.INFO, "OUT: DoLineSearchWithModel");
140          controlMessageBroadcaster.send(ControlMessages.DoLineSearchWithModel);
141          modelAndDescentDirectionBroadcaster.send(new Pair<>(model, descentDirection));
142        } else {
143          LOG.log(Level.INFO, "OUT: DoLineSearch");
144          controlMessageBroadcaster.send(ControlMessages.DoLineSearch);
145          descentDriectionBroadcaster.send(descentDirection);
146        }
147        final Pair<Vector, Integer> lineSearchEvals = lineSearchEvaluationsReducer.reduce();
148        if (lineSearchEvals != null) {
149          final int numExamples = lineSearchEvals.getSecond();
150          lineSearchResults = lineSearchEvals.getFirst();
151          lineSearchResults.scale(1.0 / numExamples);
152          LOG.log(Level.INFO, "OUT: #Examples: {0}", numExamples);
153          LOG.log(Level.INFO, "OUT: LineSearchEvals: {0}", lineSearchResults);
154          allDead = false;
155        } else {
156          allDead = true;
157        }
158      }
159
160      sendModel = chkAndUpdate();
161    } while (allDead || !ignoreAndContinue && sendModel);
162    return lineSearchResults;
163  }
164
165  private Pair<Double, Vector> computeLossAndGradient() throws NetworkException, InterruptedException {
166    Pair<Double, Vector> returnValue = null;
167    boolean allDead = false;
168    do {
169      try (final Timer t = new Timer("Broadcast(" + (sendModel ? "Model" : "MinEta") + ") + Reduce(LossAndGradient)")) {
170        if (sendModel) {
171          LOG.log(Level.INFO, "OUT: ComputeGradientWithModel");
172          controlMessageBroadcaster.send(ControlMessages.ComputeGradientWithModel);
173          modelBroadcaster.send(model);
174        } else {
175          LOG.log(Level.INFO, "OUT: ComputeGradientWithMinEta");
176          controlMessageBroadcaster.send(ControlMessages.ComputeGradientWithMinEta);
177          minEtaBroadcaster.send(minEta);
178        }
179        final Pair<Pair<Double, Integer>, Vector> lossAndGradient = lossAndGradientReducer.reduce();
180
181        if (lossAndGradient != null) {
182          final int numExamples = lossAndGradient.getFirst().getSecond();
183          LOG.log(Level.INFO, "OUT: #Examples: {0}", numExamples);
184          final double lossPerExample = lossAndGradient.getFirst().getFirst() / numExamples;
185          LOG.log(Level.INFO, "OUT: Loss: {0}", lossPerExample);
186          final double objFunc = (lambda / 2) * model.norm2Sqr() + lossPerExample;
187          LOG.log(Level.INFO, "OUT: Objective Func Value: {0}", objFunc);
188          final Vector gradient = lossAndGradient.getSecond();
189          gradient.scale(1.0 / numExamples);
190          LOG.log(Level.INFO, "OUT: Gradient: {0}", gradient);
191          returnValue = new Pair<>(objFunc, gradient);
192          allDead = false;
193        } else {
194          allDead = true;
195        }
196      }
197      sendModel = chkAndUpdate();
198    } while (allDead || !ignoreAndContinue && sendModel);
199    return returnValue;
200  }
201
202  private boolean chkAndUpdate() {
203    long t1 = System.currentTimeMillis();
204    final GroupChanges changes = communicationGroupClient.getTopologyChanges();
205    long t2 = System.currentTimeMillis();
206    LOG.log(Level.INFO, "OUT: Time to get TopologyChanges = " + (t2 - t1) / 1000.0 + " sec");
207    if (changes.exist()) {
208      LOG.log(Level.INFO, "OUT: There exist topology changes. Asking to update Topology");
209      t1 = System.currentTimeMillis();
210      communicationGroupClient.updateTopology();
211      t2 = System.currentTimeMillis();
212      LOG.log(Level.INFO, "OUT: Time to get TopologyChanges = " + (t2 - t1) / 1000.0 + " sec");
213      return true;
214    } else {
215      LOG.log(Level.INFO, "OUT: No changes in topology exist. So not updating topology");
216      return false;
217    }
218  }
219
220  private boolean converged(final int iters, final double gradNorm) {
221    return iters >= maxIters || Math.abs(gradNorm) <= 1e-3;
222  }
223
224  private double findMinEta(final Vector theModel, final Vector descentDir, final Vector lineSearchEvals) {
225    final double wNormSqr = theModel.norm2Sqr();
226    final double dNormSqr = descentDir.norm2Sqr();
227    final double wDotd = theModel.dot(descentDir);
228    final double[] t = ts.getT();
229    int i = 0;
230    for (final double eta : t) {
231      final double modelNormSqr = wNormSqr + (eta * eta) * dNormSqr + 2 * eta * wDotd;
232      final double loss = lineSearchEvals.get(i) + ((lambda / 2) * modelNormSqr);
233      lineSearchEvals.set(i, loss);
234      ++i;
235    }
236    LOG.log(Level.INFO, "OUT: Regularized LineSearchEvals: {0}", lineSearchEvals);
237    final Tuple<Integer, Double> minTup = lineSearchEvals.min();
238    LOG.log(Level.INFO, "OUT: MinTup: {0}", minTup);
239    final double minT = t[minTup.getKey()];
240    LOG.log(Level.INFO, "OUT: MinT: {0}", minT);
241    return minT;
242  }
243
244  private Vector getDescentDirection(final Vector gradient) {
245    gradient.multAdd(lambda, model);
246    gradient.scale(-1);
247    LOG.log(Level.INFO, "OUT: DescentDirection: {0}", gradient);
248    return gradient;
249  }
250}