Skip to content

Commit

Permalink
Merge pull request #968 from Nitish1814/model-104-fix
Browse files Browse the repository at this point in the history
OSS main fix
  • Loading branch information
sonalgoyal authored Nov 25, 2024
2 parents 9757186 + 3ad0908 commit 8446a31
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public class Blocker<S,D,R,C,T> {

BlockingTreeUtil<S,D,R,C,T> blockingTreeUtil;

public Blocker(BlockingTreeUtil<S,D,R,C,T> blockingTreeUtilUtil){
public Blocker(BlockingTreeUtil<S,D,R,C,T> blockingTreeUtil){
this.blockingTreeUtil = blockingTreeUtil;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import zingg.common.client.MatchType;
import zingg.common.client.ZFrame;
import zingg.common.client.ZinggClientException;
import zingg.common.client.options.ZinggOption;
import zingg.common.client.util.ColName;
import zingg.common.client.util.ColValues;
import zingg.common.core.context.IContext;
Expand All @@ -20,9 +19,7 @@
public abstract class ZinggBase<S,D, R, C, T> extends ZinggBaseCommon<S, D, R, C, T> {

protected IArguments args;
protected IContext<S,D,R,C,T> context;
protected String name;
protected ZinggOption zinggOption;
protected long startTime;
protected ClientOptions clientOptions;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import org.apache.commons.logging.LogFactory;

import zingg.common.client.ZinggClientException;
import zingg.common.core.executor.blockingverifier.IVerifyBlockingPipes;
import zingg.common.core.executor.validate.BlockerValidator;
import zingg.common.core.executor.validate.LabellerValidator;
import zingg.common.core.executor.validate.LinkerValidator;
Expand Down Expand Up @@ -44,8 +45,10 @@ public void getBaseExecutors() throws ZinggClientException, IOException{
Trainer<S, D, R, C, T> trainer = getTrainer();
executorTesterList.add(new ExecutorTester<S, D, R, C, T>(trainer,getTrainerValidator(trainer),getConfigFile(),getModelId(),getDFObjectUtil()));

VerifyBlocking<S, D, R, C, T> verifyBlocker = getVerifyBlocker();
executorTesterList.add(new ExecutorTester<S, D, R, C, T>(verifyBlocker, new BlockerValidator<S, D, R, C, T>(verifyBlocker),getConfigFile(),getModelId(),getDFObjectUtil()));
VerifyBlocking<S, D, R, C, T> verifyBlocker = getVerifyBlocker();
IVerifyBlockingPipes<S, D, R, C> verifyBlockingPipes = getVerifyBlockingPipes();
verifyBlockingPipes.setTimestamp(verifyBlocker.getTimestamp());
executorTesterList.add(new ExecutorTester<S, D, R, C, T>(verifyBlocker, new BlockerValidator<S, D, R, C, T>(verifyBlocker, verifyBlockingPipes),getConfigFile(),getModelId(),getDFObjectUtil()));

}

Expand Down Expand Up @@ -77,5 +80,5 @@ public void getAdditionalExecutors() throws ZinggClientException, IOException{

protected abstract Linker<S, D, R, C, T> getLinker() throws ZinggClientException;

protected abstract IVerifyBlockingPipes<S, D, R, C> getVerifyBlockingPipes() throws ZinggClientException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,17 @@
public class BlockerValidator<S, D, R, C, T> extends ExecutorValidator<S, D, R, C, T> {

public static final Log LOG = LogFactory.getLog(BlockerValidator.class);
IVerifyBlockingPipes verifyBlockingPipeUtil; // = new VerifyBlockingPipes<S,D,R,C>(executor.getContext().getPipeUtil(), ((VerifyBlocking<S, D, R, C, T>) executor).getTimestamp());
IVerifyBlockingPipes<S, D, R, C> verifyBlockingPipes; // = new VerifyBlockingPipes<S,D,R,C>(executor.getContext().getPipeUtil(), ((VerifyBlocking<S, D, R, C, T>) executor).getTimestamp());

public BlockerValidator(VerifyBlocking<S, D, R, C, T> executor) {
public BlockerValidator(VerifyBlocking<S, D, R, C, T> executor, IVerifyBlockingPipes<S, D, R, C> verifyBlockingPipes) {
super(executor);
this.verifyBlockingPipes = verifyBlockingPipes;
}

@Override
public void validateResults() throws ZinggClientException {

ZFrame<D, R, C> df = executor.getContext().getPipeUtil().read(false,false,verifyBlockingPipeUtil.getCountsPipe(executor.getArgs()));
ZFrame<D, R, C> df = executor.getContext().getPipeUtil().read(false,false,verifyBlockingPipes.getCountsPipe(executor.getArgs()));
ZFrame<D, R, C> topDf = df.select(ColName.HASH_COL,ColName.HASH_COUNTS_COL).limit(3);
long blockCount = topDf.count();
LOG.info("blockCount : " + blockCount);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
import zingg.common.core.executor.Labeller;
import zingg.common.core.executor.TestExecutorsSingle;
import zingg.common.core.executor.Trainer;
import zingg.common.core.executor.blockingverifier.IVerifyBlockingPipes;
import zingg.spark.client.util.SparkDFObjectUtil;
import zingg.spark.client.util.SparkModelHelper;
import zingg.spark.client.util.SparkPipeUtil;
import zingg.spark.core.TestSparkBase;
import zingg.common.core.executor.Trainer;
import zingg.spark.core.context.ZinggSparkContext;
Expand Down Expand Up @@ -88,7 +91,12 @@ protected SparkMatcher getMatcher() throws ZinggClientException {
protected SparkLinker getLinker() throws ZinggClientException {
SparkLinker sl = new SparkLinker(ctx);
return sl;
}
}

@Override
protected IVerifyBlockingPipes<SparkSession, Dataset<Row>, Row, Column> getVerifyBlockingPipes() throws ZinggClientException {
return new SparkVerifyBlockingPipes(new SparkPipeUtil(sparkSession), getVerifyBlocker().getTimestamp(), new SparkModelHelper());
}

@Override
protected SparkTrainerValidator getTrainerValidator(Trainer<SparkSession,Dataset<Row>,Row,Column,DataType> trainer) {
Expand Down

0 comments on commit 8446a31

Please sign in to comment.