User 2124 | 2/29/2016, 11:36:39 AM
We are using the commercial version of Graphlab to train a BoostedTreesRegression model on a batch server. We export the JSON trees of the trained model (using the BoostedTreesRegression.get('trees_json') method) to a text file and move that file to another server (say, the Web server), on which we can only run Java/Scala, not Python or C++. To use the model on the Web server, we have to re-implement the BoostedTreesRegression.predict method in Scala. Our implementation reads the JSON trees, parses them, and walks each tree to the leaf corresponding to the input data. Finally, we sum the values at those leaves and add 0.5 to get the final predicted value.
This procedure works well for about 80% of the test data: our Scala 'predict' and the original Graphlab 'predict' in Python give nearly identical results. For the remaining 20%, however, the differences between Scala and Python are very large. These differences do not seem to be floating-point error between Scala and Python; they seem to come from the JSON trees themselves (e.g., Python's BoostedTreesRegression.predict gives 0.48 for a predicted value, while Scala gives 0.18). The problem shows up more often when we increase the number of trees (the max_iterations parameter, e.g., 100) and the depth (the max_depth parameter, e.g., 10).
We attach a binary exported model trained with max_depth = 4 and max_iterations = 100, together with an SFrame containing one entry for which Python and Scala disagree. In Python, the model and the SFrame can be loaded into Graphlab like this:
import graphlab as gl
model = gl.load_model('e1618.model')     # the attached exported model
sf = gl.load_sframe('data1618.sframe')   # the attached one-entry SFrame
pred = model.predict(sf)                 # predicted result from Graphlab
json_array = model.get('trees_json')     # the trees in JSON form
(Our Graphlab version is 1.8.1)
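To measure the roughly 80%/20% split described above, the predictions from Graphlab (pred) and from the Scala port, exported for the same rows in the same order, can be compared row by row. The helper below is a hypothetical sketch, not part of Graphlab; its default tolerance of 1e-6 is an arbitrary choice meant to absorb ordinary floating-point noise while still flagging rows where a different leaf was reached.

```python
from typing import List, Sequence, Tuple


def compare_predictions(
    reference: Sequence[float],
    candidate: Sequence[float],
    tol: float = 1e-6,
) -> Tuple[float, List[Tuple[int, float, float]]]:
    """Return the fraction of rows agreeing within tol and the rows that do not."""
    if len(reference) != len(candidate):
        raise ValueError("both prediction lists must cover the same rows")
    # Keep (row index, reference value, candidate value) for every disagreement.
    mismatches = [
        (i, ref, cand)
        for i, (ref, cand) in enumerate(zip(reference, candidate))
        if abs(ref - cand) > tol
    ]
    agreement = 1.0 - len(mismatches) / len(reference) if len(reference) else 1.0
    return agreement, mismatches


# Usage with the example values quoted above (Python 0.48 vs. Scala 0.18):
agreement, bad_rows = compare_predictions([0.48], [0.18])
print(agreement, bad_rows)
```

Listing the mismatching row indices, rather than only an aggregate error, makes it easy to pull out single entries like the one in the attached SFrame for a tree-by-tree trace.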
Tracing these trees by hand for this entry gives a predicted value of 0.03910099999999994, but model.predict returns 0.059862732887268066. With max_depth = 4 the difference is still small, but it is significant. We also have Python code that parses the JSON trees, but it is too complicated to paste here.
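For reference, the prediction procedure described above (walk each tree from the root to a leaf, sum the leaf values, add 0.5) can be sketched in a few lines of Python. This is not our actual code and not Graphlab's implementation: the nested node layout below is hypothetical rather than the real trees_json schema, and both the strict x < threshold test and the per-split default branch for missing values are assumptions. Those two details are where a re-implementation can silently take a different branch than the original, so they are worth checking against the exported trees.

```python
import json
import math
from typing import Any, List, Mapping, Optional

BASE_SCORE = 0.5  # constant added to the summed leaf values, as described above

# Hypothetical node layout (NOT Graphlab's actual trees_json schema):
#   internal node: {"feature": str, "threshold": float, "missing": "yes" | "no",
#                   "yes": <node>, "no": <node>}
#   leaf node:     {"leaf": float}


def leaf_value(node: Mapping[str, Any], row: Mapping[str, Optional[float]]) -> float:
    """Walk one tree from the root to a leaf and return that leaf's value."""
    while "leaf" not in node:
        x = row.get(node["feature"])
        if x is None or math.isnan(x):
            # Assumption: each split stores a default branch for missing values.
            branch = node.get("missing", "no")
        elif x < node["threshold"]:
            # Assumption: "yes" means x < threshold (strict); using "<=" would
            # change the path for inputs that equal a threshold exactly.
            branch = "yes"
        else:
            branch = "no"
        node = node[branch]
    return float(node["leaf"])


def predict(trees: List[Mapping[str, Any]], row: Mapping[str, Optional[float]]) -> float:
    """Sum the leaf reached in every tree and add the base score."""
    return BASE_SCORE + sum(leaf_value(tree, row) for tree in trees)


# Usage with a made-up two-level tree parsed from a JSON string:
tree = json.loads(
    '{"feature": "x1", "threshold": 0.5, "missing": "no",'
    ' "yes": {"leaf": 0.1},'
    ' "no": {"feature": "x2", "threshold": 2.0, "missing": "yes",'
    ' "yes": {"leaf": -0.2}, "no": {"leaf": 0.3}}}'
)
print(predict([tree, tree], {"x1": 0.7, "x2": 1.0}))
```

Because a single wrongly taken branch swaps one leaf value for another, errors of this kind show up as large jumps in the final sum rather than as small rounding noise, and they become more likely as the number of trees and their depth grow.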
Could anyone please look into this and tell us what is wrong with our method? Any help is greatly appreciated!