001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package org.apache.reef.examples.group.bgd.parameters;
020
021import org.apache.reef.examples.group.bgd.loss.LossFunction;
022import org.apache.reef.tang.Configuration;
023import org.apache.reef.tang.Tang;
024import org.apache.reef.tang.annotations.Parameter;
025import org.apache.reef.tang.formats.CommandLine;
026
027import javax.inject.Inject;
028
029/**
030 * Control parameters for BGD example.
031 */
032public final class BGDControlParameters {
033
034  private final int dimensions;
035  private final double lambda;
036  private final double eps;
037  private final int iters;
038  private final int minParts;
039  private final boolean rampup;
040
041  private final double eta;
042  private final double probOfSuccessfulIteration;
043  private final BGDLossType lossType;
044
045  @Inject
046  public BGDControlParameters(
047      @Parameter(ModelDimensions.class) final int dimensions,
048      @Parameter(Lambda.class) final double lambda,
049      @Parameter(Eps.class) final double eps,
050      @Parameter(Eta.class) final double eta,
051      @Parameter(ProbabilityOfSuccessfulIteration.class) final double probOfSuccessfulIteration,
052      @Parameter(Iterations.class) final int iters,
053      @Parameter(EnableRampup.class) final boolean rampup,
054      @Parameter(MinParts.class) final int minParts,
055      final BGDLossType lossType) {
056    this.dimensions = dimensions;
057    this.lambda = lambda;
058    this.eps = eps;
059    this.eta = eta;
060    this.probOfSuccessfulIteration = probOfSuccessfulIteration;
061    this.iters = iters;
062    this.rampup = rampup;
063    this.minParts = minParts;
064    this.lossType = lossType;
065  }
066
067  public Configuration getConfiguration() {
068    return Tang.Factory.getTang().newConfigurationBuilder()
069        .bindNamedParameter(ModelDimensions.class, Integer.toString(this.dimensions))
070        .bindNamedParameter(Lambda.class, Double.toString(this.lambda))
071        .bindNamedParameter(Eps.class, Double.toString(this.eps))
072        .bindNamedParameter(Eta.class, Double.toString(this.eta))
073        .bindNamedParameter(ProbabilityOfSuccessfulIteration.class, Double.toString(probOfSuccessfulIteration))
074        .bindNamedParameter(Iterations.class, Integer.toString(this.iters))
075        .bindNamedParameter(EnableRampup.class, Boolean.toString(this.rampup))
076        .bindNamedParameter(MinParts.class, Integer.toString(this.minParts))
077        .bindNamedParameter(LossFunctionType.class, lossType.lossFunctionString())
078        .build();
079  }
080
081  public static CommandLine registerShortNames(final CommandLine commandLine) {
082    return commandLine
083        .registerShortNameOfClass(ModelDimensions.class)
084        .registerShortNameOfClass(Lambda.class)
085        .registerShortNameOfClass(Eps.class)
086        .registerShortNameOfClass(Eta.class)
087        .registerShortNameOfClass(ProbabilityOfSuccessfulIteration.class)
088        .registerShortNameOfClass(Iterations.class)
089        .registerShortNameOfClass(EnableRampup.class)
090        .registerShortNameOfClass(MinParts.class)
091        .registerShortNameOfClass(LossFunctionType.class);
092  }
093
094  public int getDimensions() {
095    return this.dimensions;
096  }
097
098  public double getLambda() {
099    return this.lambda;
100  }
101
102  public double getEps() {
103    return this.eps;
104  }
105
106  public double getEta() {
107    return this.eta;
108  }
109
110  public double getProbOfSuccessfulIteration() {
111    return probOfSuccessfulIteration;
112  }
113
114  public int getIters() {
115    return this.iters;
116  }
117
118  public int getMinParts() {
119    return this.minParts;
120  }
121
122  public boolean isRampup() {
123    return this.rampup;
124  }
125
126  public Class<? extends LossFunction> getLossFunction() {
127    return this.lossType.getLossFunction();
128  }
129}