001/* 002 * Licensed to the Apache Software Foundation (ASF) under one 003 * or more contributor license agreements. See the NOTICE file 004 * distributed with this work for additional information 005 * regarding copyright ownership. The ASF licenses this file 006 * to you under the Apache License, Version 2.0 (the 007 * "License"); you may not use this file except in compliance 008 * with the License. You may obtain a copy of the License at 009 * 010 * http://www.apache.org/licenses/LICENSE-2.0 011 * 012 * Unless required by applicable law or agreed to in writing, 013 * software distributed under the License is distributed on an 014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 015 * KIND, either express or implied. See the License for the 016 * specific language governing permissions and limitations 017 * under the License. 018 */ 019package org.apache.reef.examples.group.bgd.parameters; 020 021import org.apache.reef.examples.group.bgd.loss.LossFunction; 022import org.apache.reef.tang.Configuration; 023import org.apache.reef.tang.Tang; 024import org.apache.reef.tang.annotations.Parameter; 025import org.apache.reef.tang.formats.CommandLine; 026 027import javax.inject.Inject; 028 029/** 030 * Control parameters for BGD example. 031 */ 032public final class BGDControlParameters { 033 034 private final int dimensions; 035 private final double lambda; 036 private final double eps; 037 private final int iters; 038 private final int minParts; 039 private final boolean rampup; 040 041 private final double eta; 042 private final double probOfSuccessfulIteration; 043 private final BGDLossType lossType; 044 045 @Inject 046 public BGDControlParameters( 047 @Parameter(ModelDimensions.class) final int dimensions, 048 @Parameter(Lambda.class) final double lambda, 049 @Parameter(Eps.class) final double eps, 050 @Parameter(Eta.class) final double eta, 051 @Parameter(ProbabilityOfSuccessfulIteration.class) final double probOfSuccessfulIteration, 052 @Parameter(Iterations.class) final int iters, 053 @Parameter(EnableRampup.class) final boolean rampup, 054 @Parameter(MinParts.class) final int minParts, 055 final BGDLossType lossType) { 056 this.dimensions = dimensions; 057 this.lambda = lambda; 058 this.eps = eps; 059 this.eta = eta; 060 this.probOfSuccessfulIteration = probOfSuccessfulIteration; 061 this.iters = iters; 062 this.rampup = rampup; 063 this.minParts = minParts; 064 this.lossType = lossType; 065 } 066 067 public Configuration getConfiguration() { 068 return Tang.Factory.getTang().newConfigurationBuilder() 069 .bindNamedParameter(ModelDimensions.class, Integer.toString(this.dimensions)) 070 .bindNamedParameter(Lambda.class, Double.toString(this.lambda)) 071 .bindNamedParameter(Eps.class, Double.toString(this.eps)) 072 .bindNamedParameter(Eta.class, Double.toString(this.eta)) 073 .bindNamedParameter(ProbabilityOfSuccessfulIteration.class, Double.toString(probOfSuccessfulIteration)) 074 .bindNamedParameter(Iterations.class, Integer.toString(this.iters)) 075 .bindNamedParameter(EnableRampup.class, Boolean.toString(this.rampup)) 076 .bindNamedParameter(MinParts.class, Integer.toString(this.minParts)) 077 .bindNamedParameter(LossFunctionType.class, lossType.lossFunctionString()) 078 .build(); 079 } 080 081 public static CommandLine registerShortNames(final CommandLine commandLine) { 082 return commandLine 083 .registerShortNameOfClass(ModelDimensions.class) 084 .registerShortNameOfClass(Lambda.class) 085 .registerShortNameOfClass(Eps.class) 086 .registerShortNameOfClass(Eta.class) 087 .registerShortNameOfClass(ProbabilityOfSuccessfulIteration.class) 088 .registerShortNameOfClass(Iterations.class) 089 .registerShortNameOfClass(EnableRampup.class) 090 .registerShortNameOfClass(MinParts.class) 091 .registerShortNameOfClass(LossFunctionType.class); 092 } 093 094 public int getDimensions() { 095 return this.dimensions; 096 } 097 098 public double getLambda() { 099 return this.lambda; 100 } 101 102 public double getEps() { 103 return this.eps; 104 } 105 106 public double getEta() { 107 return this.eta; 108 } 109 110 public double getProbOfSuccessfulIteration() { 111 return probOfSuccessfulIteration; 112 } 113 114 public int getIters() { 115 return this.iters; 116 } 117 118 public int getMinParts() { 119 return this.minParts; 120 } 121 122 public boolean isRampup() { 123 return this.rampup; 124 } 125 126 public Class<? extends LossFunction> getLossFunction() { 127 return this.lossType.getLossFunction(); 128 } 129}