Google Cloud Platform – Understanding Federated Learning on Cloud
Crowdsourcing has a wide range of benefits. Whether it's restaurant reviews that help us find the perfect place for dinner or crowdfunding to bring our favorite TV show back to life, these distributed contributions combine to make some very useful tools.
We can use the same concept to build better machine learning models. In this article, we examine a different approach to machine learning. Standard machine learning approaches require centralizing the training data in a common store. Suppose we want to train a keyboard prediction model based on user interactions. Traditionally, we would collect all the data on a server, train a model there, and then serve it; clients talk to the server to get predictions. Because the model and the data are all in one central location, this setup is simple to build and manage.
The downside of this centralized setup is that the back-and-forth communication can hurt the user experience because of network latency, poor connectivity, battery drain, and other unpredictable issues.
One way to solve this is to have each client independently train its own model using its own data right on the device. No communication is necessary.
The problem is that no individual device has enough data to train a good model on its own. You could pre-train the model on the server and then deploy it.
But that model goes stale. In our smart keyboard example, if everyone started using a trendy new word today, a model trained on yesterday's data would not know about it and would be less useful.
To benefit from decentralized data while preserving users' privacy, we can use federated learning. The core idea behind federated learning is decentralized training, in which user data is never sent to a central server.
You start with a model on the server and distribute it to the clients.
But you can't simply deploy to every client, because training on a device that someone is using would degrade their experience. Instead, you identify eligible clients: those that are available, plugged in, and not in use. Then you narrow these down to suitable clients, because not all of them will have sufficient data. Once you've identified suitable devices, you deploy the model to them. Each client then trains the model locally on its own data and produces an updated model, which is sent back to the server.
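The device-selection step can be sketched as a simple two-stage filter. This is a minimal illustration under stated assumptions, not Google's implementation: the `Client` record, its field names, the `select_clients` helper, and the `min_examples` threshold are all hypothetical stand-ins for whatever signals a real system collects.

```python
from dataclasses import dataclass


# Hypothetical client record; the field names are illustrative assumptions,
# not part of any real federated learning API.
@dataclass
class Client:
    client_id: str
    available: bool    # device is online and reachable
    plugged_in: bool   # charging, so training won't drain the battery
    idle: bool         # not currently in use by its owner
    num_examples: int  # size of the local training set


def select_clients(clients, min_examples=100, max_clients=None):
    """Return clients that are eligible (available, plugged in, idle) and
    suitable (enough local data), preserving their original order."""
    eligible = [c for c in clients if c.available and c.plugged_in and c.idle]
    suitable = [c for c in eligible if c.num_examples >= min_examples]
    if max_clients is not None:
        suitable = suitable[:max_clients]
    return suitable


if __name__ == "__main__":
    clients = [
        Client("a", True, True, True, 500),
        Client("b", True, False, True, 800),  # on battery: skipped
        Client("c", True, True, False, 300),  # in use: skipped
        Client("d", True, True, True, 20),    # too little data: skipped
    ]
    print([c.client_id for c in select_clients(clients)])  # ['a']
```

Separating eligibility (will training disturb the user?) from suitability (is there enough data to help?) mirrors the two checks described above; a real system would likely also cap and randomize how many clients take part in each round.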
The key point is that the data used to train the model on each device never leaves that device. Only the weights, biases, and other parameters learned by the model are sent back.
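As a minimal sketch of the on-device step, assume a toy one-dimensional linear model y = w·x + b trained with plain gradient descent. The function name `local_train`, the model, and the hyperparameters are illustrative assumptions. What matters is the interface: the raw (x, y) pairs are only read inside the function, and what comes back is the parameter dictionary plus the local example count (included because a weighted average on the server can use it, which is a design choice of this sketch).

```python
def local_train(params, local_data, lr=0.01, epochs=50):
    """Train a toy 1-D linear model y = w*x + b on this device's data.

    params: dict with keys "w" and "b" (the model received from the server).
    local_data: non-empty list of (x, y) pairs that exist only on this device.
    Returns (updated_params, num_examples); the (x, y) pairs themselves
    are never returned.
    """
    w, b = params["w"], params["b"]
    n = len(local_data)
    for _ in range(epochs):
        # Gradients of the mean squared error with respect to w and b,
        # both computed from the current (w, b) before either is updated.
        grad_w = sum(2 * (w * x + b - y) * x for x, y in local_data) / n
        grad_b = sum(2 * (w * x + b - y) for x, y in local_data) / n
        w -= lr * grad_w
        b -= lr * grad_b
    # Only the learned parameters (and a count) leave the device.
    return {"w": w, "b": b}, n


if __name__ == "__main__":
    device_data = [(x, 2 * x + 1) for x in range(5)]  # stays on the device
    update, n = local_train({"w": 0.0, "b": 0.0}, device_data, lr=0.05, epochs=2000)
    print(n, round(update["w"], 3), round(update["b"], 3))  # 5 2.0 1.0
```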
The server then collects the locally trained models and averages their parameters, effectively creating a new master model.
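The averaging step can be illustrated with a weighted mean of each parameter. The article only says the models are averaged; weighting each client by its number of training examples, as in the Federated Averaging (FedAvg) algorithm, is an assumption of this sketch, and with equal example counts it reduces to a plain average. `federated_average` is a hypothetical helper name, and parameters are plain floats here to keep the example dependency-free.

```python
def federated_average(client_updates):
    """Combine client models into a new global model.

    client_updates: list of (params, num_examples) pairs, where params maps
    parameter names to floats. Each parameter is averaged across clients,
    weighted by how many examples that client trained on.
    """
    total = sum(n for _, n in client_updates)
    if total == 0:
        raise ValueError("no training examples reported")
    names = client_updates[0][0].keys()
    return {
        name: sum(params[name] * n for params, n in client_updates) / total
        for name in names
    }


if __name__ == "__main__":
    updates = [({"w": 1.0, "b": 0.0}, 100), ({"w": 3.0, "b": 2.0}, 300)]
    print(federated_average(updates))  # {'w': 2.5, 'b': 1.5}
```

Weighting by example count keeps a client with very little data from pulling the global model as hard as one with a lot of data.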
Doing this once is not enough. The process repeats over many rounds: the combined model from one round becomes the initial model for the next, and with every round it gets a little better thanks to the data from all participating clients.
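The repeated rounds can be simulated end to end with a deliberately tiny model. In this toy setup (an assumption for illustration only), the "model" is a single number `theta` that should match the mean of all clients' values. Each selected client runs a few gradient steps from the current global model on its own values, and the server's weighted average becomes the starting point for the next round. `local_update`, `run_rounds`, and all their parameters are hypothetical.

```python
import random


def local_update(theta, values, lr=0.1, steps=5):
    """Client side: a few gradient steps on the mean of (theta - v)^2,
    starting from the global model. Only theta and a count are returned."""
    for _ in range(steps):
        grad = sum(2 * (theta - v) for v in values) / len(values)
        theta -= lr * grad
    return theta, len(values)


def run_rounds(client_data, num_rounds=20, clients_per_round=3, seed=0):
    """Server side: repeat select -> train locally -> average for several rounds."""
    rng = random.Random(seed)
    theta = 0.0  # initial model on the server
    history = []
    for _ in range(num_rounds):
        # Pick which clients take part in this round.
        chosen = rng.sample(client_data, clients_per_round)
        updates = [local_update(theta, values) for values in chosen]
        total = sum(n for _, n in updates)
        # The weighted average becomes the initial model for the next round.
        theta = sum(t * n for t, n in updates) / total
        history.append(theta)
    return theta, history


if __name__ == "__main__":
    data = [[1.0, 2.0, 3.0], [10.0], [4.0, 6.0]]  # each list lives on one client
    theta, history = run_rounds(data, num_rounds=20, clients_per_round=3)
    print(round(theta, 4))  # 4.3333
```

With every client participating in each round, `theta` approaches 26/6 ≈ 4.33, the mean of all six values, even though no client ever shares its raw numbers. Sampling fewer clients per round (for example `clients_per_round=2`) makes the estimate fluctuate from round to round, because each round sees only part of the data.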
If you've used Google's keyboard, Gboard, you have already experienced a federated learning use case. When Gboard shows a suggested query, your phone locally stores information about the current context and whether you clicked the suggestion.
Federated Learning processes that history on-device to suggest improvements to the next iteration of Gboard’s query suggestion model.