Mitigating Bias in AI Using Debias-GAN
In This White Paper
Today's AI and Machine Learning (ML) algorithms have achieved spectacular results in automating decisions that were traditionally made by humans. However, the data used for model training may be imbalanced and may introduce discriminatory biases against specific groups of people. Natural Language Processing (NLP) models are gaining popularity in contexts such as resume screening, college admission, emotion assessment, recidivism prediction, and more. Consequently, it is increasingly important to recognize the role these models play in reinforcing societal biases and stereotypes. NLP models trained on historical data are rarely optimized to reduce implicit biases, and in some cases they further perpetuate them. Bias in machine learning models manifests as a strong association between attributes that ought not to be correlated. In this white paper, we propose a general framework, debias-GAN, that addresses this issue by explicitly augmenting the training dataset of an NLP model with underrepresented instances synthesized by a pretrained sequence-generating model. As a proof of concept, we experimented with a deep classification model of tweets and aimed to decorrelate its predictions from user ethnicity. The synthetic data is generated by a targeted language model (LM) that produces realistic but user-ethnicity-oblivious tweets. We trained such debiased LMs with generative adversarial networks (GANs) through reinforcement learning (RL), adding a penalty term to the loss function that discourages, via policy updates, sequences that strongly indicate user ethnicity. The debiasing reward is provided by an independently trained classifier that identifies user ethnicity from tweets. We experimented with the mixing ratio of real and synthetic data and measured the debiasing impact with three fairness metrics. Debias-GAN improves the fairness metrics of the classifier by up to sevenfold while largely maintaining classification performance.
Thoughts from WWT's Diversity & Inclusion Group
It is important to take a holistic approach as a society to eliminate racism and drive diversity, inclusion, and equity. As we work to create environments of inclusion, we cannot allow the systems and tools we use to create and perpetuate biases. Racism is structural and institutional, and it operates through both explicit (conscious) and implicit (unconscious) biases. With the rapid proliferation of artificial intelligence and the expanding range of areas in which it is used, it is critical to debias data so that fair and equitable decisions can be made.
Business use case
Artificial intelligence (AI) has been growing rapidly and playing an ever more important role in data-driven decision making. It is used across industries and in a multitude of real-life scenarios, including sensitive areas such as recruitment, healthcare (e.g., medical referral and diagnosis), and criminal justice. The broad application of AI has substantially improved efficiency and reduced costs for companies and governments by automating and optimizing processes. At the same time, concerns about AI decisions and their unintended consequences have grown, particularly in natural language processing (NLP)-related tasks. In one notable case, an algorithm widely used in US hospitals required African American patients to be sicker than their white counterparts before it recommended them for more advanced care programs (Ledford, 2019). In another study, researchers found that gender imbalance in medical imaging datasets produces gender-biased classifiers for computer-aided diagnosis (Agostina, Nieto, Peterson, Milone, & Ferrante, 2020).
In this paper, we focus on biased data, the core of bias in AI, and explore the potential of using AI to detect and mitigate it. To simplify our pilot study, we chose to focus on "conversational" vs. "non-conversational" tweets as a first proxy for the presence of bias, and we reserve the study of actual protected categories for future work. We define the bias as a learned association between the conversational nature of a tweet and the ethnicity of its author: tweets from white users are more frequently associated with the conversational label than tweets from African American users. Such learned associations resemble explicit and implicit biases in real life, which often take the form of false beliefs that certain groups of people always possess certain traits. By training a generator model to produce user-ethnicity-oblivious tweets, we were able to build a more balanced dataset of real and synthetic tweets and thereby reduce bias. Our study shows the feasibility of using AI to address biased data, one of the many sources of bias in AI, and our approach could be applied to real-life tasks that are currently plagued by inequities in AI applications.
3.1 Technical background
Several approaches have been proposed for tackling bias in NLP from two different perspectives:
(1) Balancing text corpora and their representations (retraining)
(2) Adjusting the algorithm (inference) (Sun, et al., 2019).
Herein, we briefly review each perspective, using gender as the example attribute whose bias is to be removed.
Retraining can be done by balancing the underlying input corpus
Retraining methods require that the model be trained again on corrected input corpora and word embeddings. To balance the original dataset, two techniques are frequently used: counterfactual augmentation, i.e., adding copies of the data in which gender-specific terms are swapped (e.g., "men" with "women"), and name anonymization, i.e., replacing gender-indicative names such as "Mary" with anonymized entities such as "E" (Zhao, Wang, Yatskar, Vicente, & Chang, 2018). As another text-source enhancement technique, gender tagging has been shown to be effective in some learning tasks, such as Neural Machine Translation (NMT) from a language that does not mark the speaker's gender into one that does. At the data level, gender tagging prepends a tag indicating the gender of the speaker to every data point, e.g., "[MALE] I am happy" is translated to "Je suis heureux" and "[FEMALE] I am happy" to "Je suis heureuse" (Vanmassenhove, Hardmeier, & Way, 2019).
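The two data-level techniques above can be illustrated with a minimal sketch. It assumes simple lower-cased whitespace tokenization and uses a tiny, hypothetical list of gendered word pairs (`GENDER_PAIRS`); real systems rely on much larger curated lists and must resolve ambiguous words. Name anonymization is not shown.

```python
from typing import Dict, List, Tuple

# Tiny, hypothetical word-pair list for illustration only. Real lists are much
# larger and must handle ambiguity (e.g., "her" can map to "him" or "his").
GENDER_PAIRS: List[Tuple[str, str]] = [
    ("he", "she"), ("him", "her"), ("man", "woman"),
    ("men", "women"), ("father", "mother"), ("king", "queen"),
]


def build_swap_map(pairs: List[Tuple[str, str]]) -> Dict[str, str]:
    """Map each gendered word to its counterpart, in both directions."""
    swap: Dict[str, str] = {}
    for male, female in pairs:
        swap[male] = female
        swap[female] = male
    return swap


SWAP = build_swap_map(GENDER_PAIRS)


def counterfactual(sentence: str) -> str:
    """Lower-case and whitespace-tokenize the sentence, then swap listed words."""
    return " ".join(SWAP.get(tok, tok) for tok in sentence.lower().split())


def augment(corpus: List[str]) -> List[str]:
    """Counterfactual augmentation: keep every original, add its swapped copy."""
    return corpus + [counterfactual(s) for s in corpus]


def gender_tag(sentence: str, gender: str) -> str:
    """Prepend a speaker-gender tag such as "[FEMALE]" to a source sentence."""
    return f"[{gender.upper()}] {sentence}"


print(augment(["he is a king"]))           # ['he is a king', 'she is a queen']
print(gender_tag("I am happy", "female"))  # [FEMALE] I am happy
```

Keeping the originals alongside their swapped copies is what balances the corpus: every gendered context now appears with both genders.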
Another method to retrain is rebalancing the word embeddings
Beyond directly manipulating input text data, another category of retraining methods rebalances word embeddings. This line of work has developed steadily since it was first proposed in 2016 (Bolukbasi, Chang, Zou, Saligrama, & Kalai, 2016). For example, by maximizing the difference between male and female word embeddings along a dedicated gender dimension, while keeping gender information out of the remaining, neutral dimensions, this approach offers great flexibility and has been shown to be effective in correcting bias (Zhao, Zhou, Li, Wang, & Chang, 2018). However, when debiasing LMs used for text generation, this method falls short in preserving the fidelity of the generated sequences.
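The GN-GloVe objective itself modifies GloVe training and is too involved for a short example. As a simpler, plainly swapped-in illustration of what rebalancing an embedding means geometrically, the sketch below implements the "neutralize" step of the earlier hard-debiasing method (Bolukbasi, Chang, Zou, Saligrama, & Kalai, 2016): the component of a gender-neutral word vector along an estimated gender direction is projected out. The 3-dimensional vectors are made up, and the mean-difference estimate of the gender direction is a simplification of that paper's PCA-based estimate.

```python
import math
from typing import List, Sequence, Tuple

Vector = List[float]


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


def normalize(v: Sequence[float]) -> Vector:
    norm = math.sqrt(dot(v, v))
    return [a / norm for a in v]


def gender_direction(pairs: List[Tuple[Vector, Vector]]) -> Vector:
    """Unit vector along the mean (male - female) difference of definitional
    pairs. Bolukbasi et al. use PCA over several pairs; the mean is a
    simplification for this sketch."""
    dim = len(pairs[0][0])
    mean = [0.0] * dim
    for male, female in pairs:
        for i in range(dim):
            mean[i] += (male[i] - female[i]) / len(pairs)
    return normalize(mean)


def neutralize(word: Vector, g: Vector) -> Vector:
    """Project out the component of `word` along the unit gender direction `g`.
    (The original method also re-normalizes the result to unit length.)"""
    proj = dot(word, g)
    return [w - proj * gi for w, gi in zip(word, g)]


# Made-up 3-d vectors: the first coordinate encodes gender.
he, she = [1.0, 0.2, 0.0], [-1.0, 0.2, 0.0]
g = gender_direction([(he, she)])
programmer = [0.4, 0.5, 0.7]
print(neutralize(programmer, g))  # [0.0, 0.5, 0.7]
```

After neutralization, "programmer" is equidistant from "he" and "she" along the gender direction, while its other coordinates are untouched.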
Retraining is impractical for large pre-trained models
At the same time, retraining methods are not practical for adjusting modern pretrained LMs such as OpenAI's GPT-3: with billions of parameters, such models would require massive amounts of corrected training data and computation time. To address these shortcomings, algorithm adjustment methods have been developed more recently. One major breakthrough came from the deep learning community for image processing: generative adversarial networks (GANs) are known for synthesizing diverse images with high fidelity (Brock, Donahue, & Simonyan, 2018). By altering the discriminator architecture, conditional (controllable) GANs make it possible to control specific attributes of the generated output (Bo, Fidler, Urtasun, & Lin, 2017).
Algorithm adjustment tweaks the model implementation to debias data
One variation of the traditional GAN uses multi-objective optimization: the generator is trained with respect to a protected attribute such as gender and pursues two objectives at the same time, one defined by a discriminator and one by a protected-attribute classifier (Ramaswamy, Sunnis, & Russakovsky, 2020). In a related adversarial-learning setup, instead of classifying sequences as "real" or "fake", the discriminator is trained to identify gender in a given task. This approach is generalizable: it can be applied to various debiasing use cases and to any model trained by gradient descent, such as a classifier, a word embedding model, or a language model.
Using SeqGAN with reinforcement learning to generate realistic synthetic data that mitigates ethnicity bias
Two particular challenges arise when an image GAN framework is applied directly to generating sequences, as in training an LM. First, GANs are designed to generate continuous data (such as pixel values in grayscale and RGB images), not discrete values (such as tokens), so gradients cannot flow from the discriminator back through the sampled tokens. Second, sequence generation is modeled as a Markov process in which the discriminator can provide feedback to the generator only once a sequence is finished (Bahdanau, et al., 2016). The same difficulty arises when an additional classifier is imposed to debias the generator: it is non-trivial to assign feedback to a partially generated sequence. To solve both problems, SeqGAN, proposed in 2017, trains the generative model as a reinforcement learning (RL) agent (Yu, Wang, & Yu, 2017). In SeqGAN, the state is the sequence of tokens generated so far, and the action is the next token to be generated. The reward is provided by the discriminator, which evaluates the completed sequence and guides the learning of the generator. At each step, Monte Carlo (MC) search is used to approximate the state-action value of the partial sequence. The policy is trained through policy gradient (Sutton, McAllester, Singh, & Mansour, 1999), which avoids the need to differentiate through discrete token choices. Building on SeqGAN, our two-step debiasing strategy can be summarized as follows. First, we leveraged the RL framework to train a tweet generator that is penalized for generating biased tweets during adversarial training. Second, we employed the "debiased" generator to synthesize a balanced dataset as the input for the downstream text classifier.
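The MC search and policy-gradient steps described above can be sketched in plain Python. This is a simplified illustration, not SeqGAN's TensorFlow implementation: `rollout` and `reward` are hypothetical callables (in SeqGAN the rollout policy is a copy of the generator and the reward comes from the discriminator), and the toy rollout and reward at the bottom exist only to make the sketch runnable.

```python
import random
from typing import Callable, List, Sequence

Token = str
# rollout(prefix, seq_len, rng) -> a completed sequence of length seq_len
RolloutFn = Callable[[List[Token], int, random.Random], List[Token]]
# reward(full_sequence) -> scalar, e.g. a discriminator's P(real)
RewardFn = Callable[[Sequence[Token]], float]


def mc_action_values(seq: Sequence[Token], rollout: RolloutFn, reward: RewardFn,
                     n_rollouts: int = 16, seed: int = 0) -> List[float]:
    """Estimate Q for every prefix y_1..y_t of a generated sequence.

    For t < T, Q is the mean reward of n_rollouts completions sampled by the
    rollout policy (Monte Carlo search). For t == T the sequence is complete,
    so its reward is used directly.
    """
    rng = random.Random(seed)
    seq_len = len(seq)
    q_values = []
    for t in range(1, seq_len + 1):
        prefix = list(seq[:t])
        if t == seq_len:
            q_values.append(reward(prefix))
        else:
            total = sum(reward(rollout(prefix, seq_len, rng))
                        for _ in range(n_rollouts))
            q_values.append(total / n_rollouts)
    return q_values


def policy_gradient_loss(log_probs: Sequence[float],
                         q_values: Sequence[float]) -> float:
    """REINFORCE-style surrogate: -sum_t log pi(y_t | y_<t) * Q_t.

    Minimizing it raises the probability of tokens whose continuations earn
    high reward; no gradient has to pass through the discrete sampling step.
    """
    return -sum(lp * q for lp, q in zip(log_probs, q_values))


def uniform_rollout(prefix: List[Token], seq_len: int,
                    rng: random.Random) -> List[Token]:
    """Toy rollout policy: complete the prefix with uniformly random tokens."""
    return prefix + [rng.choice(["a", "b"]) for _ in range(seq_len - len(prefix))]


def share_of_a(s: Sequence[Token]) -> float:
    """Toy reward: fraction of tokens equal to "a"."""
    return s.count("a") / len(s)


print(mc_action_values(["b", "b", "a"], uniform_rollout, share_of_a))
```

In a real generator, `log_probs` would be tensors produced by the network so that the surrogate loss can be differentiated with respect to the policy parameters; plain floats are used here only to show the arithmetic.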
Understanding the ethnicity bias between conversational and non-conversational tweets
In this paper, we aim to mitigate ethnicity bias in a classifier that predicts whether a tweet is conversational. By "conversational," we mean tweets that contain "@" mentions. When trained on real tweets, a deep-learning classifier learned a strong correlation between the conversational attribute and the ethnicity attribute, which we attempted to reduce. We chose the conversational attribute as a preliminary proxy for traits that can be spuriously associated with ethnicity, in order to keep the modelling approach simple for this pilot study. This choice also eliminated the need to manually tag the protected categories, which would be time-consuming and is part of future work.
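A minimal sketch of the conversational label, assuming a tweet counts as conversational when it contains at least one @mention. The regular expression `MENTION` is our assumption; the labels used in this study came with the dataset (see Section 4.1). Mentions normalized to "@USER" by the BERTweet tokenizer are matched as well.

```python
import re

# Assumed mention pattern: "@" followed by word characters, not preceded by a
# word character, so an email address such as "a@b.com" does not count.
MENTION = re.compile(r"(?<!\w)@\w+")


def is_conversational(tweet: str) -> bool:
    """Return True if the tweet contains at least one @mention."""
    return MENTION.search(tweet) is not None


print(is_conversational("@USER lol same"))        # True
print(is_conversational("good morning everyone"))  # False
```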
3.2 Model architecture
(a) Tweet generator model with a long short-term memory (LSTM) layer, which consists of multiple LSTM cells. With three trainable gates (forget gate f, update gate i, and output gate o), each LSTM cell can carry long-term dependencies through the cell state c from one time step to the next. The generator is trained by maximum likelihood estimation (MLE) in a supervised fashion. (b) Discriminator model with a multi-layer CNN architecture. Tweets have a fixed length of 10 tokens in our case; each tweet is converted to a matrix representation by a trainable word embedding layer (BERTweet), followed by multiple sets of convolutions and a pooling layer, before a fully connected layer with dropout and a softmax output. (c) In the adversarial training for debiasing, the generator (G), discriminator (D), and classifier (C) are combined under an RL framework. On the left, D is trained on real tweets and generated (fake) tweets. On the right, G is trained by policy gradient, where the final reward signal is provided by both D and the pre-trained C. The reward can be positive or negative (a punishment) and is passed back to the intermediate action values via Monte Carlo search.
3.3 Fairness metrics
Fairness metrics measure a model's behavior across different protected classes. Multiple metrics have been proposed, each capturing a different notion of fairness. In this paper, we use two kinds of fairness metrics. The first is the Difference in Equality of Opportunity (DEO): the absolute difference in false negative rates between protected classes, computed for a given target class. The second is Bias Amplification (BA): the difference between the actual percentage of a protected attribute among instances that have the target attribute and its percentage among instances predicted to have the target attribute. BA thus measures how much more (or less) often the model predicts the target attribute together with the protected attribute than the ground truth warrants.
Explaining the fairness metrics using an example
The example in Figure 5 illustrates a dataset with gender attributes, identifying each person as male or female (Bolukbasi, Chang, Zou, Saligrama, & Kalai, 2016). The task is to classify each person as a homemaker or a boss. The training data shows an inherent bias: females are less likely (46%) to be bosses than males (58%).
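The DEO and BA definitions given under Section 3.3 can be sketched in plain Python. This is a minimal illustration, not the implementation used in this study. The 0/1 encoding of labels and groups, the sign convention for BA (predicted share minus actual share), and the toy data, which loosely mirrors the homemaker/boss example but is invented, are all assumptions.

```python
from typing import Sequence


def _rate(hits: int, total: int) -> float:
    if total == 0:
        raise ValueError("empty group: metric is undefined")
    return hits / total


def deo(y_true: Sequence[int], y_pred: Sequence[int],
        protected: Sequence[int], target: int = 1) -> float:
    """|FNR(group 0) - FNR(group 1)| for one target class, where FNR is the
    share of instances with y_true == target that are not predicted as target."""
    fnrs = []
    for group in (0, 1):
        idx = [i for i, (y, a) in enumerate(zip(y_true, protected))
               if y == target and a == group]
        misses = sum(1 for i in idx if y_pred[i] != target)
        fnrs.append(_rate(misses, len(idx)))
    return abs(fnrs[0] - fnrs[1])


def bias_amplification(y_true: Sequence[int], y_pred: Sequence[int],
                       protected: Sequence[int], target: int = 1,
                       group: int = 1) -> float:
    """Share of `group` among predicted-`target` instances minus its share among
    true-`target` instances (assumed sign convention). Positive values mean the
    model over-associates `target` with `group`."""
    true_idx = [i for i, y in enumerate(y_true) if y == target]
    pred_idx = [i for i, y in enumerate(y_pred) if y == target]
    share_true = _rate(sum(protected[i] == group for i in true_idx), len(true_idx))
    share_pred = _rate(sum(protected[i] == group for i in pred_idx), len(pred_idx))
    return share_pred - share_true


# Hypothetical toy data: target 1 = "boss", protected 0 = male, 1 = female.
y_true = [1, 1, 1, 1, 1, 1, 0, 0]
protected = [0, 0, 0, 1, 1, 1, 0, 1]
y_pred = [1, 1, 1, 1, 0, 0, 0, 0]
print(round(deo(y_true, y_pred, protected), 3))                          # 0.667
print(bias_amplification(y_true, y_pred, protected, target=1, group=1))  # -0.25
```

In this toy case the classifier misses two of the three female bosses, so the DEO is large, and females make up only 25% of predicted bosses versus 50% of actual bosses, so BA is negative for that group: the model amplifies the association between "boss" and "male".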
4.1 Data description and implementation
For this paper, we employed a dataset from Blodgett's group, described in "Demographic dialectal variation in social media: A case study of African-American English" (Blodgett, Green, & O'Connor, 2016). In this dataset, 20M tweets from Caucasian (labelled 'naa') and African American (labelled 'aa') users were collected through the Tweepy API. The race and @mention labels for each tweet were taken directly from the existing Uni Weimar PAN16 challenge, which tagged tweets by race (White American, African American). The processed dataset is available from the paper "Adversarial Removal of Demographic Attributes from Text Data" (Elazar & Goldberg, 2018).
Based on the data, and on our aim of providing a general-purpose debiasing model, we chose the target attribute to be whether a tweet is conversational (proxied by @mention usage). The protected attribute was chosen to be the ethnicity of the Twitter user. The goal was to train a classifier that accurately identifies conversational tweets without racial bias, as measured by the fairness metrics described in the previous section. We employed a long short-term memory (LSTM)-based language model (LM) as the generator to synthesize tweets token by token. In this Markov process, each next word was chosen at random by sampling from the learned conditional probability distribution over words, given the preceding sequence of words. We adopted a CNN architecture for both the discriminator and the classifier for ethnicity (ethnicity classifier).
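The sampling step described above can be illustrated with a minimal sketch. `next_dist` is a hypothetical callable standing in for the trained LSTM LM, and the bigram table is invented; unlike an LSTM, a bigram table conditions only on the last token, so it is a simplification used purely to make the example runnable.

```python
import random
from typing import Callable, Dict, List

# next_dist(tokens_so_far) -> {candidate_next_token: probability}
NextDist = Callable[[List[str]], Dict[str, float]]


def sample_sequence(next_dist: NextDist, prefix: List[str], length: int,
                    rng: random.Random) -> List[str]:
    """Extend `prefix` to `length` tokens, sampling each next token from the
    conditional distribution given all tokens generated so far."""
    seq = list(prefix)
    while len(seq) < length:
        dist = next_dist(seq)
        tokens = list(dist)
        weights = [dist[tok] for tok in tokens]
        seq.append(rng.choices(tokens, weights=weights, k=1)[0])
    return seq


# Hypothetical bigram table standing in for a trained LSTM: it only looks at
# the last token, whereas an LSTM conditions on the whole prefix.
BIGRAMS: Dict[str, Dict[str, float]] = {
    "<s>": {"good": 0.6, "@USER": 0.4},
    "good": {"morning": 0.7, "night": 0.3},
    "@USER": {"thanks": 1.0},
    "morning": {"!": 1.0},
    "night": {"!": 1.0},
    "thanks": {"!": 1.0},
    "!": {"!": 1.0},
}


def bigram_dist(seq: List[str]) -> Dict[str, float]:
    return BIGRAMS[seq[-1]]


print(sample_sequence(bigram_dist, ["<s>"], 4, random.Random(0)))
```

Because each token is drawn rather than chosen greedily, repeated calls with different seeds yield different tweets, which is what allows the generator to produce a diverse synthetic dataset.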
The real tweets were first pre-processed and tokenized with BERTweet (Dat Quoc Nguyen, 2020), which is pre-trained on 850M tweets. For the proof of concept, we selected 20K tweets from the original dataset that are exactly 10 tokens long, with a vocabulary of 40K unique tokens. These tweets were used to pre-train the LM in a supervised fashion by MLE. The discriminator was initialized and pre-trained on a mix of synthetic and real tweets of the same length (10 tokens). The pre-trained generative model was then refined with the SeqGAN architecture. In parallel, we trained an ethnicity classifier on 20M real tweets of various lengths. To debias the LM, after SeqGAN converged, we combined the discriminator reward with the ethnicity classifier reward in the debias-GAN and iterated until convergence. To benchmark performance, a classifier with the same CNN structure (the @mention classifier) was trained to identify conversational tweets, first on real tweets only and then on real tweets mixed with synthetic tweets. The entire model was implemented in TensorFlow and deployed on four NVIDIA Tesla K80 GPUs on Amazon Web Services (AWS).
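The paragraph above states that the discriminator reward and the ethnicity classifier reward were combined, but not the exact formula. The function below is one hypothetical way to do it, consistent with penalizing tweets that strongly indicate user ethnicity and with a reward that can turn negative. The linear form, the weight `lam`, and the penalty `2 * |p_aa - 0.5|` are our assumptions, not the study's settings.

```python
def debias_reward(p_real: float, p_aa: float, lam: float = 1.0) -> float:
    """Hypothetical combined reward for one completed synthetic tweet.

    p_real: discriminator's probability that the tweet is real.
    p_aa:   ethnicity classifier's probability that the author is African American.
    lam:    hypothetical weight trading realism against debiasing.

    The penalty 2 * |p_aa - 0.5| is 0 when the classifier cannot tell the
    author's ethnicity (p_aa = 0.5) and 1 when it is certain either way, so
    confidently ethnicity-revealing tweets can earn a negative reward.
    """
    for p in (p_real, p_aa):
        if not 0.0 <= p <= 1.0:
            raise ValueError("probabilities must lie in [0, 1]")
    penalty = 2.0 * abs(p_aa - 0.5)
    return p_real - lam * penalty


print(debias_reward(0.8, 0.5))             # 0.8: realistic and ethnicity-ambiguous
print(round(debias_reward(0.8, 0.95), 2))  # -0.1: realistic but ethnicity-revealing
```

A symmetric penalty was chosen so that the generator is pushed toward ambiguity rather than toward one particular group; a larger `lam` favors debiasing over realism.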
4.2.1 GAN convergence and SeqGAN with pre-trained BERTweet model
Initially, we encountered problems with mode collapse: the generator got stuck in a non-optimal state and produced only a narrow variety of tweets. When first training the LM using an LSTM, we found that the MLE loss did not converge; rather, it fluctuated around a baseline, indicating that the LM was not improving. Furthermore, when the LM was passed to adversarial training, the generator loss (in red) decreased rapidly in the initial stages but later spiked at several points during training. Moreover, the discriminator loss (in green) fluctuated significantly, another sign of mode collapse in the generator.
To fix this issue, we used the pre-trained language model (BERTweet) in the GAN. We used real-world tweets as MLE training targets so that the pre-trained language model learns the kind of output we expect when generating tweets. The generator is first trained independently on the real-world tweets and is later improved with feedback from the race classifier so that it generates race-oblivious tweets.
After adding MLE training, we observed that the MLE loss curve (in red) decreased steadily as the model improved incrementally after each epoch. This in turn reduced the generator loss and led to generator convergence. Additionally, the discriminator loss (in green in Figure 9 below) was much smoother than with the basic LSTM model.
(a) GAN training without pre-trained embeddings. Although the generator and discriminator losses in the min-max game trend downward, the MLE loss (blue) fluctuates around the same baseline, indicating that the quality of the LM does not necessarily improve. The generator loss (red) decreases rapidly at first but spikes at various stages of training, and the discriminator loss (green) fluctuates significantly. Altogether, this behavior indicates mode collapse in the generator. (b) GAN training with pre-trained embeddings. Due to the additional MLE training, the MLE loss (red) decreases steadily, indicating incremental improvement when training with MLE on real tweets. Interestingly, the MLE training significantly improves the convergence of the generator loss (red), as indicated by the rapid drops marked by red arrows. The discriminator loss (green) is steady and smooth. It is likely that the additional MLE training guides the GAN training away from poor local optima and improves the diversity and fidelity of the generated tweets.
Explaining mode collapse using an example
Comparing the tweets generated by the two models, we observe that the tweets generated without the pre-trained model were dominated by a single token, @USER, which is characteristic of mode collapse. In contrast, the model with pre-trained embeddings generates a diverse range of tokens with improved grammatical structure (punctuation, conjugation of words, and choice of symbols and emoji).
Left: Using only an LSTM model, after 200 passes of GAN training, the generated tweets lack diversity. The most common token is @USER, which is characteristic of mode collapse. Right: Using a pre-trained model (BERTweet), the generated tweets are diverse (different tokens), and their grammatical structure is improved (punctuation, conjugation of words, and choice of symbols and emoji). Note: The redacted parts of the generated tweets contain inappropriate language.
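The comparison above is qualitative. One simple way to quantify it, not used in this study and shown here only as an assumption-laden sketch, is to measure distinct-n (the fraction of unique n-grams) and the share of the most frequent token across a batch of generated tweets. The token lists below are made up and are not the paper's generated outputs.

```python
from collections import Counter
from typing import List, Sequence, Tuple


def distinct_n(tweets: Sequence[List[str]], n: int = 1) -> float:
    """distinct-n: number of unique n-grams divided by total n-grams."""
    ngrams = [tuple(t[i:i + n]) for t in tweets for i in range(len(t) - n + 1)]
    return len(set(ngrams)) / len(ngrams) if ngrams else 0.0


def top_token_share(tweets: Sequence[List[str]]) -> Tuple[str, float]:
    """Most frequent token and the fraction of all tokens it accounts for."""
    counts = Counter(tok for t in tweets for tok in t)
    token, count = counts.most_common(1)[0]
    return token, count / sum(counts.values())


# Made-up outputs, not the paper's generated tweets
collapsed = [["@USER"] * 4, ["@USER", "@USER", "lol", "@USER"]]
diverse = [["good", "morning", "!"], ["see", "you", "soon"]]
print(distinct_n(collapsed), top_token_share(collapsed))  # 0.25 ('@USER', 0.875)
print(distinct_n(diverse))                                # 1.0
```

A low distinct-n together with a single token accounting for most of the output is a simple numerical signature of the mode collapse visible in the left panel.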
4.2.2 Debias-GAN convergence and improved fairness metrics
We mixed synthetic tweets from the generator with real tweets at varying ratios and used each mixture as the input for the conversation classifier. In all experiments, the conversation classifier was trained in exactly the same manner except for the input dataset. After training, we calculated the fairness metrics on held-out real tweets that were not included in the training data. In most cases, we observed improved fairness metrics (Figure 13, BA and DEOs) across mix ratios, with a moderate decrease in classification performance (Figure 13, AUC: area under the curve). Interestingly, the effect of the number of synthetic tweets on the fairness metrics appears to be non-monotonic and differs from metric to metric. Overall, with 5% synthetic tweets we achieved substantial debiasing on nearly all metrics (with a very minor increase in BA for all tweets), followed by the 1% mix.
For the baseline, the classifier is trained on real tweets only. When the real tweets were mixed with 1%, 5%, and 10% synthetic tweets from the debiased generator, the effects on the individual metrics varied. In the current experiment, the 5% mix ratio provides the best debiasing result while maintaining good classification performance.
The model trained on real tweets alone amplified the bias in the real data, capturing a lower percentage of conversational tweets from African American users in its predictions. This amplification was reduced when the model was trained on a combination of real and synthetic data and tested on the same set of real tweets.
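The mixing step used in these experiments can be sketched as follows. The paper does not state whether "5% synthetic" is relative to the number of real tweets or to the size of the final mix; this sketch assumes the former, and the function name `mix_datasets`, the `(text, label)` tuple layout, and the placeholder data are all hypothetical.

```python
import random
from typing import List, Sequence, Tuple

Example = Tuple[str, int]  # (tweet text, conversational label)


def mix_datasets(real: Sequence[Example], synthetic: Sequence[Example],
                 ratio: float, seed: int = 0) -> List[Example]:
    """Return all real examples plus round(ratio * len(real)) synthetic ones,
    shuffled. Assumes `ratio` is relative to the number of real examples."""
    if ratio < 0:
        raise ValueError("ratio must be non-negative")
    n_synthetic = round(ratio * len(real))
    if n_synthetic > len(synthetic):
        raise ValueError("not enough synthetic examples for this ratio")
    rng = random.Random(seed)
    mixed = list(real) + rng.sample(list(synthetic), n_synthetic)
    rng.shuffle(mixed)
    return mixed


# Placeholder data; in practice these would be real and generated tweets.
real = [(f"real tweet {i}", i % 2) for i in range(200)]
synthetic = [(f"synthetic tweet {i}", i % 2) for i in range(50)]
for ratio in (0.0, 0.01, 0.05, 0.10):  # baseline plus the three mixes tried
    mixed = mix_datasets(real, synthetic, ratio)
    print(ratio, len(mixed))
```

Fixing the seed keeps the sampled synthetic subset identical across runs, so differences between classifiers can be attributed to the mix ratio rather than to which synthetic tweets happened to be drawn.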
We have presented a general framework for debiasing machine learning models across a broad range of natural language processing tasks by augmenting the input data with a mixture of real sequences and synthetic sequences generated by a specially tuned language model. As a proof of concept, we focused on a classifier that identifies whether a tweet is conversational, and we reduced the learned association between the conversational nature of a tweet and the ethnicity of its author. The learned associations defined in this study reflect explicit and implicit biases in a larger context, since the conversational nature of a tweet serves as a proxy for any trait that can be incorrectly correlated with specific populations. We trained a generator model to synthesize ethnicity-oblivious tweets through a GAN optimized via policy gradient, with part of the reward provided by a separately pre-trained ethnicity classifier. We then trained the conversational classifier on a mixture of real and synthetic tweets and compared its classification performance and fairness with those of a baseline model trained on real tweets only. Across mixture ratios, we observe a moderate decrease in classification performance, but the best mixing ratio yields up to a sevenfold improvement in the fairness metrics. Our study provides a concrete example of using AI to fix biased input data that correlates conversational tweets with user ethnicity. The same approach could also be applied to real-life tasks to combat biases in AI, especially those that result from biased training data.
For future work, the debiasing method can be improved in the following directions:
- Mix synthetic data of varying sequence lengths with real data as input to downstream natural language processing models.
- An improved language model to generate more realistic and diverse synthetic data. To this end, a larger model such as OpenAI's GPT-2 or GPT-3 could be leveraged, fine-tuning only the last few layers rather than training the model from scratch.
- Within the debias-GAN, Monte Carlo tree search and a value network could replace the current rollout strategy to improve decision making in long-term planning (Silver, et al., 2016).
- It would be interesting to evaluate SeqGAN on non-English tweets. We tried Mandarin posts from Weibo, and the code is included in the repository.
- Experiment with novel reward functions within the debias-GAN to achieve better convergence.
We expect the debias-GAN framework to be widely applied in a broad range of use cases, to improve AI decision making, and to promote diversity and inclusion.
1) Agostina, L., Nieto, N., Peterson, V., Milone, D. H., & Ferrante, E. (2020, June 9). Gender imbalance in medical imaging datasets produces biased classifiers for computer-aided diagnosis. Retrieved from PNAS: https://www.pnas.org/content/117/23/12592
2) Bahdanau, D., Brakel, P., Xu, K., Goyal, A., Lowe, R., Pineau, J., . . . Bengio, Y. (2016). An actor-critic algorithm for sequence prediction. arXiv preprint arXiv:1607.07086.
3) Blodgett, S. L., Green, L., & O'Connor, B. (2016). Demographic dialectal variation in social media: A case study of African-American English. arXiv:1608.08868, 15.
4) Bo, D., Fidler, S., Urtasun, R., & Lin, D. (2017). Towards diverse and natural image descriptions via a conditional GAN. Proceedings of the IEEE International Conference on Computer Vision, 2970-2979.
5) Bolukbasi, T., Chang, K.-W., Zou, J., Saligrama, V., & Kalai, A. (2016). Man is to computer programmer as woman is to homemaker? Debiasing word embeddings. Advances in Neural Information Processing Systems 29.
6) Brock, A., Donahue, J., & Simonyan, K. (2018). Large scale GAN training for high fidelity natural image synthesis. arXiv preprint arXiv:1809.11096.
7) Crawford, K. (2017). The trouble with bias. Keynote at Neural Information Processing Systems (NIPS'17).
8) Dat Quoc Nguyen, T. V. (2020). BERTweet: A pre-trained language model for English Tweets. arXiv preprint arXiv.
9) Elazar, Y., & Goldberg, Y. (2018). Adversarial removal of demographic attributes from text data. arXiv preprint arXiv:1808.06640.
10) Hao, K. (2019, February 4). This is how AI bias really happens--and why it's so hard to fix. Retrieved from Technology Review: https://www.technologyreview.com/2019/02/04/137602/this-is-how-ai-bias-really-happensand-why-its-so-hard-to-fix/
11) Ledford, H. (2019, October 24). Millions of black people affected by racial bias in health-care algorithms. Retrieved from Nature: https://www.nature.com/articles/d41586-019-03228-6
12) Liu, R., Jia, C., Wei, J., Xu, G., Wang, L., & Vosoughi, S. (2021). Mitigating political bias in language models through reinforced calibration. Proceedings of the AAAI Conference on Artificial Intelligence.
13) Ramaswamy, V. V., Sunnis, K. S., & Russakovsky, O. (2020). Fair attribute classification through latent space de-biasing. arXiv preprint arXiv:2012.01469.
14) Roselli, D., Matthews, J., & Talagala, N. (2019, May). Managing bias in AI. Retrieved from ACM Digital Library: https://dl.acm.org/doi/abs/10.1145/3308560.3317590
15) Silver, D., Huang, A., Maddison, C. J., Guez, A., Sifre, L., Van Den Driessche, G., . . . Sutsk. (2016). Mastering the game of Go with deep neural networks and tree search. Nature, 529, 484-489.
16) Sun, T., Gaut, A., Tang, S., Huang, Y., ElSherief, M., Zhao, J., . . . Wang, W. Y. (2019). Mitigating gender bias in natural language processing: Literature review. arXiv preprint arXiv:1906.08976.
17) Sutton, R. S., McAllester, D. A., Singh, S. P., & Mansour, Y. (1999). Policy gradient methods for reinforcement learning with function approximation. NIPS, 1057-1063.
18) Vanmassenhove, E., Hardmeier, C., & Way, A. (2019). Getting gender right in neural machine translation. arXiv preprint arXiv:1909.05088.
19) Wang, A., & Russakovsky, O. (2021). Directional bias amplification. arXiv preprint arXiv:2102.12594.
20) Yu, L., Wang, W., & Yu, Y. (2017). SeqGAN: Sequence generative adversarial nets with policy gradient. Proceedings of the AAAI Conference on Artificial Intelligence, 31(1).
21) Zhao, J., Wang, T., Yatskar, M., Vicente, O., & Chang, K.-W. (2018). Gender bias in coreference resolution: Evaluation and debiasing methods. arXiv preprint arXiv:1804.06876.
22) Zhao, J., Zhou, Y., Li, Z., Wang, W., & Chang, K.-W. (2018). Learning gender-neutral word embeddings. arXiv preprint arXiv:1809.01496.