Publications
2024
- [Under Review] TabGrad: Tabular Learning via Critique-Driven Iterative Prompt Refinement with LLMs. In Submission.
Large Language Models (LLMs) demonstrate impressive reasoning and problem-solving capabilities. Practitioners elicit reasoning from LLMs by designing creative prompting techniques for inference, such as In-Context Learning and Chain-of-Thought. In particular, Verbalized Machine Learning (VML) can solve predictive problems by iteratively “backpropagating” textual feedback provided by LLMs. This form of output gives practitioners without prior machine learning knowledge natural-language interpretability, allowing them to trace the optimization progress through the lens of natural language. However, a drawback of these methods is that they require resource-intensive LLMs (e.g., Llama 70B) to obtain competitive results. Running these LLMs necessitates multiple costly GPUs (e.g., 2× NVIDIA A100 80GB), as they can consume up to 140GB of GPU VRAM. To remedy this, we propose TabGrad, a variant of VML that incorporates novel concepts, including “Critiques”, “Patience”, and “Moving Average”, to improve the generalization performance of LLMs of various sizes, especially smaller LLMs such as Gemma 9B. “Critiques” enable the LLM to optimize based on its mistakes, “Patience” ensures that non-beneficial rules do not get included in the prompt during training, and “Moving Average” gives the optimized prompt a holistic and balanced overview of the entire training set. Through comprehensive experiments on real-world datasets, we demonstrate that these concepts improve generalization performance and training stability, increasing predictive accuracy and reducing variance across multiple test seeds.
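Below is a minimal sketch of how such a critique-driven refinement loop could be organized. Everything here is an illustrative assumption rather than TabGrad's actual API: the `llm` callable, the `Example` container, the prompt templates, and the reading of “Moving Average” as an exponential average of batch accuracy.

```python
# Sketch of critique-driven iterative prompt refinement (VML-style).
# All names and prompt templates are illustrative, not the paper's API.
from dataclasses import dataclass

@dataclass
class Example:
    features: str  # serialized tabular row, e.g. "age=42, income=55k"
    label: str     # ground-truth class as text

def refine_prompt(llm, prompt, batches, patience=3, ema=0.5):
    best_acc, stale, avg_acc = 0.0, 0, 0.0
    for batch in batches:
        preds = [llm(f"{prompt}\nInput: {ex.features}\nLabel:") for ex in batch]
        mistakes = [(ex, p) for ex, p in zip(batch, preds) if p.strip() != ex.label]
        acc = 1.0 - len(mistakes) / len(batch)
        # "Moving Average": smooth the accuracy signal so no single batch
        # dominates the decision to keep or discard a candidate rule set.
        avg_acc = ema * avg_acc + (1.0 - ema) * acc
        # "Critiques": have the LLM explain its own mistakes, then fold the
        # critique into a revised candidate prompt (textual "backpropagation").
        critique = llm(
            "You misclassified these rows:\n"
            + "\n".join(f"{ex.features} -> {p} (true: {ex.label})" for ex, p in mistakes)
            + f"\nCritique the current rules and suggest fixes:\n{prompt}"
        )
        candidate = llm(f"Rewrite the rules, applying this critique:\n{critique}")
        # "Patience": commit the candidate only while the smoothed accuracy
        # keeps improving; stop after `patience` stale batches so
        # non-beneficial rules never enter the prompt.
        if avg_acc > best_acc:
            best_acc, stale, prompt = avg_acc, 0, candidate
        else:
            stale += 1
            if stale >= patience:
                break
    return prompt
```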
- [Under Review] TabUnite: A Universal Categorical Encoding Scheme for Mixed-Type Tabular Data. In Submission.
Flow matching and diffusion generative models for tabular data face challenges in modeling heterogeneous feature interrelationships, especially in data with continuous and categorical input features. Capturing these interrelationships is crucial, as it allows these models to understand complex patterns and dependencies in the underlying data. A promising way to address the challenge is to devise suitable encoding schemes for the input features before the generative modeling process. However, prior methods often rely either on suboptimal heuristics, such as one-hot encoding of categorical features followed by separate modeling of categorical and continuous features, or on latent-space diffusion models. Instead, our proposed solution unifies the data space and jointly applies a single generative process across all the encodings, efficiently capturing heterogeneous feature interrelationships. Specifically, it employs encoding schemes such as PSK Encoding, Dictionary Encoding, and Analog Bits that effectively convert categorical features into continuous ones. Extensive experiments on datasets comprising heterogeneous features demonstrate that our encoding schemes, combined with Flow Matching or Diffusion as the generative model, significantly enhance model capabilities. Our TabUnite models help address data heterogeneity, achieving superior performance across a broad suite of datasets, baselines, and benchmarks while generating accurate, robust, and diverse tabular data.
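As a concrete illustration of one such encoding, here is a minimal NumPy sketch of the Analog Bits idea: a category index becomes a {-1,+1}-valued bit vector that a continuous diffusion or flow-matching process can model alongside numerical columns, and noisy model outputs are decoded back by thresholding. The function names and little-endian bit order are assumptions for illustration, not TabUnite's implementation.

```python
# Sketch of Analog Bits encoding for categorical columns (illustrative only).
import numpy as np

def analog_bits_encode(idx: np.ndarray, num_classes: int) -> np.ndarray:
    """Map integer category indices to {-1,+1}-valued bit vectors."""
    n_bits = max(1, int(np.ceil(np.log2(num_classes))))
    bits = (idx[:, None] >> np.arange(n_bits)) & 1  # little-endian binary
    return bits.astype(np.float32) * 2.0 - 1.0      # {0,1} -> {-1,+1}

def analog_bits_decode(x: np.ndarray) -> np.ndarray:
    """Threshold (possibly noisy) continuous outputs back to indices."""
    bits = (x > 0).astype(np.int64)
    return (bits << np.arange(bits.shape[1])).sum(axis=1)

cats = np.array([0, 3, 5])                    # a column with 6 categories
z = analog_bits_encode(cats, num_classes=6)   # shape (3, 3), values in {-1,+1}
assert (analog_bits_decode(z) == cats).all()  # round-trips exactly
```

The appeal of this family of encodings is that the generative model never has to treat categorical columns specially: after encoding, every column lives in a continuous space and a single joint process covers all features.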
- [ICML Spotlight] InterpreTabNet: Distilling Predictive Signals from Tabular Data by Salient Feature Interpretation. Jacob Si, Wendy Yusi Cheng, Michael Cooper, and Rahul Krishnan. In the 41st International Conference on Machine Learning, 2024. Spotlight Presentation [top 3.5%].
Tabular data are omnipresent across industry sectors. Neural networks for tabular data, such as TabNet, have been proposed to make predictions while leveraging the attention mechanism for interpretability. However, the inferred attention masks are often dense, making it challenging to derive rationales for the predictive signal. To remedy this, we propose InterpreTabNet, a variant of the TabNet model that models the attention mechanism as a latent variable sampled from a Gumbel-Softmax distribution. This enables us to regularize the model to learn distinct concepts in the attention masks via a KL divergence regularizer, which prevents overlapping feature selection by promoting sparsity. Sparsity maximizes the model's efficacy and improves interpretability when determining the features that matter for predicting the outcome. To assist in interpreting the feature interdependencies learned by our model, we employ a large language model (GPT-4) and use prompt engineering to map the learned feature mask onto natural-language text describing the learned signal. Through comprehensive experiments on real-world datasets, we demonstrate that InterpreTabNet outperforms previous methods for interpreting tabular data while attaining competitive accuracy.
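A minimal PyTorch sketch of the two mechanisms described above follows; the mask shape, the constant sparse prior, and the weighting of the penalty into the loss are illustrative assumptions, not the paper's exact formulation.

```python
# Sketch: Gumbel-Softmax attention mask + KL sparsity penalty (illustrative).
import torch
import torch.nn.functional as F

def sample_mask(logits: torch.Tensor, tau: float = 0.5) -> torch.Tensor:
    """Differentiable, nearly one-hot feature mask, shape (batch, n_features)."""
    return F.gumbel_softmax(logits, tau=tau, hard=False, dim=-1)

def sparsity_kl(mask: torch.Tensor, prior_p: float = 0.05) -> torch.Tensor:
    """KL-style penalty against a small constant prior: pushes entries toward 0."""
    m = mask.clamp_min(1e-8)
    return (m * (m / prior_p).log()).sum(dim=-1).mean()

logits = torch.randn(32, 10)  # per-sample scores over 10 features
mask = sample_mask(logits)    # feature-attention mask for one decision step
reg = sparsity_kl(mask)       # add to the task loss, e.g. loss + 0.1 * reg
```

In the full model, each decision step's mask would receive this penalty, discouraging the steps from selecting overlapping features and keeping the learned masks distinct.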
2022
- [Book Chapter] Assessing Infant Mortality Rate: Problems Stemming from Household Living Conditions, Women's Education and Health. Jacob Si and Rohan Alexander. In "Telling Stories with Data: With Applications in R" by Rohan Alexander.
What areas can be improved to promote the well-being of women in India and thereby reduce the infant mortality rate? Using data from the 1998-1999 India National Family Health Survey provided by the Demographic and Health Surveys (DHS) program, we depict the demographics of Indian women and infants across different states of India. We find that high infant mortality rates are rooted in poor living conditions, which reduce the likelihood that women attain education and understand the importance of antenatal care and birth delivery assistance. We also explore other factors, such as potentially heritable traits (unhealthy body weight and anaemia) and an infant's diet. These factors are crucial to the development of an infant and the reduction of the infant mortality rate.