Tags: Machine Learning, Classification, Deep Learning, Python

Machine Learning with Reddit Data: Classification, Prediction, and Beyond

By @ml_practitioner | February 16, 2026 | 25 min read

Reddit data offers rich opportunities for machine learning: predicting post virality, classifying content, detecting spam, and building recommendation systems. This guide covers practical ML applications with production-ready code and model evaluation strategies.

Learning Objectives

By the end of this guide you will be able to engineer features from Reddit data, train classification and regression models, handle imbalanced datasets, and evaluate and deploy models.

ML Applications for Reddit Data

| Application | Task Type | Target Variable | Common Algorithms |
|---|---|---|---|
| Virality Prediction | Regression / Classification | Upvotes, engagement tier | XGBoost, Neural Networks |
| Subreddit Classification | Multi-class Classification | Subreddit name | BERT, FastText, SVM |
| Spam Detection | Binary Classification | Spam / Not spam | Random Forest, LSTM |
| User Intent | Multi-class Classification | Question, Opinion, News, etc. | BERT, Logistic Regression |
| Content Recommendation | Ranking | Relevance score | Collaborative Filtering, Matrix Factorization |

Feature Engineering

Good features are critical for Reddit ML models. Here's a comprehensive feature extraction pipeline:

import pandas as pd
import numpy as np
from datetime import datetime, timezone
from textblob import TextBlob
from typing import List
import re

class RedditFeatureExtractor:
    """Extract ML features from Reddit posts."""

    def __init__(self):
        self.features = []

    def extract_text_features(self, text: str) -> dict:
        """Extract text-based features."""
        if not text:
            text = ""

        # Basic text stats
        words = text.split()
        word_count = len(words)
        char_count = len(text)
        avg_word_length = char_count / word_count if word_count > 0 else 0

        # Sentence count
        sentences = re.split(r'[.!?]+', text)
        sentence_count = len([s for s in sentences if s.strip()])

        # Special elements
        url_count = len(re.findall(r'https?://\S+', text))
        mention_count = len(re.findall(r'u/\w+', text))
        subreddit_refs = len(re.findall(r'r/\w+', text))

        # Questions and formatting
        question_marks = text.count('?')
        exclamation_marks = text.count('!')
        caps_ratio = sum(1 for c in text if c.isupper()) / char_count if char_count > 0 else 0

        # Sentiment features
        blob = TextBlob(text)
        polarity = blob.sentiment.polarity
        subjectivity = blob.sentiment.subjectivity

        return {
            'word_count': word_count,
            'char_count': char_count,
            'avg_word_length': avg_word_length,
            'sentence_count': sentence_count,
            'url_count': url_count,
            'mention_count': mention_count,
            'subreddit_refs': subreddit_refs,
            'question_marks': question_marks,
            'exclamation_marks': exclamation_marks,
            'caps_ratio': caps_ratio,
            'sentiment_polarity': polarity,
            'sentiment_subjectivity': subjectivity
        }

    def extract_temporal_features(self, timestamp: float) -> dict:
        """Extract time-based features."""
        dt = datetime.fromtimestamp(timestamp, tz=timezone.utc)  # timezone-aware UTC

        return {
            'hour': dt.hour,
            'day_of_week': dt.weekday(),
            'day_of_month': dt.day,
            'month': dt.month,
            'is_weekend': 1 if dt.weekday() >= 5 else 0,
            'is_business_hours': 1 if 9 <= dt.hour <= 17 else 0,
            'is_peak_hours': 1 if dt.hour in [10, 11, 14, 15, 20, 21] else 0
        }

    def extract_engagement_features(self, post: dict) -> dict:
        """Extract engagement-related features."""
        score = post.get('score', 0)
        num_comments = post.get('num_comments', 0)
        upvote_ratio = post.get('upvote_ratio', 0.5)

        return {
            'score': score,
            'num_comments': num_comments,
            'upvote_ratio': upvote_ratio,
            'score_log': np.log1p(max(score, 0)),
            'comments_log': np.log1p(num_comments),
            'comment_score_ratio': num_comments / (score + 1) if score > 0 else 0
        }

    def extract_all_features(self, post: dict) -> dict:
        """Extract all features from a post."""
        features = {}

        # Title features
        title_features = self.extract_text_features(post.get('title', ''))
        features.update({f'title_{k}': v for k, v in title_features.items()})

        # Body features (if selftext exists)
        body_features = self.extract_text_features(post.get('selftext', ''))
        features.update({f'body_{k}': v for k, v in body_features.items()})

        # Temporal features
        if post.get('created_utc'):
            temporal = self.extract_temporal_features(post['created_utc'])
            features.update(temporal)

        # Engagement features
        engagement = self.extract_engagement_features(post)
        features.update(engagement)

        # Post type features
        features['is_self'] = 1 if post.get('is_self') else 0
        features['is_video'] = 1 if post.get('is_video') else 0
        features['over_18'] = 1 if post.get('over_18') else 0
        features['has_thumbnail'] = 1 if post.get('thumbnail') not in ['self', 'default', '', None] else 0

        return features

    def extract_batch(self, posts: List[dict]) -> pd.DataFrame:
        """Extract features from multiple posts."""
        all_features = [self.extract_all_features(post) for post in posts]
        return pd.DataFrame(all_features)

# Usage
extractor = RedditFeatureExtractor()
features_df = extractor.extract_batch(reddit_posts)
print(f"Extracted {len(features_df.columns)} features")

Virality Prediction Model

Predict whether a post will go viral based on early features:

$ pip install xgboost scikit-learn optuna

import xgboost as xgb
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import classification_report, roc_auc_score
from sklearn.preprocessing import StandardScaler
import optuna

class ViralityPredictor:
    """Predict post virality using XGBoost."""

    def __init__(self, threshold: int = 1000):
        self.threshold = threshold  # Score threshold for "viral"
        self.model = None
        self.scaler = StandardScaler()
        self.feature_names = None

    def prepare_data(self, features_df: pd.DataFrame) -> tuple:
        """Prepare features and target variable."""
        # Create binary target: 1 if score > threshold
        y = (features_df['score'] >= self.threshold).astype(int)

        # Remove target-related features for prediction
        exclude_cols = ['score', 'score_log', 'num_comments',
                        'comments_log', 'comment_score_ratio', 'upvote_ratio']
        X = features_df.drop(columns=[c for c in exclude_cols if c in features_df.columns])

        self.feature_names = X.columns.tolist()

        return X, y

    def train(self, X: pd.DataFrame, y: pd.Series, optimize: bool = False):
        """Train the virality prediction model."""
        # Split data
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42, stratify=y
        )

        # Scale features
        X_train_scaled = self.scaler.fit_transform(X_train)
        X_test_scaled = self.scaler.transform(X_test)

        if optimize:
            best_params = self._optimize_hyperparameters(X_train_scaled, y_train)
        else:
            best_params = {
                'max_depth': 6,
                'learning_rate': 0.1,
                'n_estimators': 200,
                'min_child_weight': 1,
                'subsample': 0.8,
                'colsample_bytree': 0.8
            }

        # Handle class imbalance
        scale_pos_weight = len(y_train[y_train == 0]) / len(y_train[y_train == 1])

        self.model = xgb.XGBClassifier(
            **best_params,
            scale_pos_weight=scale_pos_weight,
            random_state=42,
            eval_metric='auc'
        )

        self.model.fit(
            X_train_scaled, y_train,
            eval_set=[(X_test_scaled, y_test)],
            verbose=50
        )

        # Evaluate
        y_pred = self.model.predict(X_test_scaled)
        y_proba = self.model.predict_proba(X_test_scaled)[:, 1]

        print("Classification Report:")
        print(classification_report(y_test, y_pred))
        print(f"ROC-AUC: {roc_auc_score(y_test, y_proba):.4f}")

        return self.model

    def _optimize_hyperparameters(self, X_train, y_train, n_trials: int = 50):
        """Optimize hyperparameters using Optuna."""

        def objective(trial):
            params = {
                'max_depth': trial.suggest_int('max_depth', 3, 10),
                'learning_rate': trial.suggest_float('learning_rate', 0.01, 0.3),
                'n_estimators': trial.suggest_int('n_estimators', 100, 500),
                'min_child_weight': trial.suggest_int('min_child_weight', 1, 10),
                'subsample': trial.suggest_float('subsample', 0.6, 1.0),
                'colsample_bytree': trial.suggest_float('colsample_bytree', 0.6, 1.0)
            }

            model = xgb.XGBClassifier(**params, random_state=42)
            scores = cross_val_score(model, X_train, y_train, cv=5, scoring='roc_auc')
            return scores.mean()

        study = optuna.create_study(direction='maximize')
        study.optimize(objective, n_trials=n_trials, show_progress_bar=True)

        print(f"Best AUC: {study.best_value:.4f}")
        return study.best_params

    def get_feature_importance(self) -> pd.DataFrame:
        """Get feature importance rankings."""
        importance = self.model.feature_importances_

        df = pd.DataFrame({
            'feature': self.feature_names,
            'importance': importance
        }).sort_values('importance', ascending=False)

        return df

    def predict(self, features: pd.DataFrame) -> np.ndarray:
        """Predict virality for new posts."""
        features_scaled = self.scaler.transform(features)
        return self.model.predict_proba(features_scaled)[:, 1]

# Usage
predictor = ViralityPredictor(threshold=500)
X, y = predictor.prepare_data(features_df)
predictor.train(X, y, optimize=True)

# Get feature importance
importance = predictor.get_feature_importance()
print(importance.head(10))

Example results: ROC-AUC 0.82 · Precision 76% · Recall 68% · F1 score 72%

Text Classification with Transformers

Classify Reddit posts into subreddits or content categories:

import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import Trainer, TrainingArguments
from datasets import Dataset
from sklearn.preprocessing import LabelEncoder
from typing import List

class SubredditClassifier:
    """Classify posts into subreddits using BERT."""

    def __init__(self, model_name: str = "distilbert-base-uncased"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = None
        self.label_encoder = LabelEncoder()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_name = model_name

    def prepare_data(self, df: pd.DataFrame) -> tuple:
        """Prepare data for training."""
        # Combine title and selftext
        texts = df.apply(
            lambda x: f"{x['title']} {x.get('selftext', '')}".strip(),
            axis=1
        ).tolist()

        # Encode labels
        labels = self.label_encoder.fit_transform(df['subreddit'])

        return texts, labels

    def create_dataset(self, texts: List[str], labels: np.ndarray) -> Dataset:
        """Create Hugging Face Dataset."""

        def tokenize(examples):
            return self.tokenizer(
                examples['text'],
                truncation=True,
                padding='max_length',
                max_length=256
            )

        data = {'text': texts, 'label': labels.tolist()}
        dataset = Dataset.from_dict(data)
        dataset = dataset.map(tokenize, batched=True)

        return dataset

    def train(
        self,
        train_texts: List[str],
        train_labels: np.ndarray,
        val_texts: List[str],
        val_labels: np.ndarray,
        output_dir: str = "./subreddit-classifier"
    ):
        """Train the classifier."""
        num_labels = len(self.label_encoder.classes_)

        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name,
            num_labels=num_labels
        )

        train_dataset = self.create_dataset(train_texts, train_labels)
        val_dataset = self.create_dataset(val_texts, val_labels)

        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=3,
            per_device_train_batch_size=16,
            per_device_eval_batch_size=32,
            warmup_steps=500,
            weight_decay=0.01,
            logging_dir="./logs",
            logging_steps=100,
            eval_strategy="steps",
            eval_steps=500,
            save_steps=1000,
            load_best_model_at_end=True,
            metric_for_best_model="accuracy"
        )

        def compute_metrics(eval_pred):
            logits, labels = eval_pred
            predictions = np.argmax(logits, axis=-1)
            accuracy = (predictions == labels).mean()
            return {'accuracy': accuracy}

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            compute_metrics=compute_metrics
        )

        trainer.train()
        return trainer

    @torch.no_grad()
    def predict(self, texts: List[str]) -> List[str]:
        """Predict subreddit for new posts."""
        self.model.eval()
        self.model.to(self.device)

        inputs = self.tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=256,
            return_tensors="pt"
        ).to(self.device)

        outputs = self.model(**inputs)
        predictions = torch.argmax(outputs.logits, dim=-1).cpu().numpy()

        return self.label_encoder.inverse_transform(predictions)

# Usage
classifier = SubredditClassifier()
texts, labels = classifier.prepare_data(df)

# Split and train
X_train, X_val, y_train, y_val = train_test_split(texts, labels, test_size=0.2, stratify=labels, random_state=42)
classifier.train(X_train, y_train, X_val, y_val)

Class Imbalance Warning

Reddit data often has severe class imbalance: many ordinary posts versus a few viral ones, popular versus niche subreddits. Mitigate it with class weights, SMOTE oversampling, focal loss (a sketch follows), or stratified sampling.
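
For deep learning classifiers, focal loss is an alternative to resampling: it down-weights easy examples so training focuses on the rare class. A minimal PyTorch sketch, assuming a binary setup; the alpha and gamma defaults below are common starting points, not values tuned for Reddit data:

import torch
import torch.nn.functional as F

class FocalLoss(torch.nn.Module):
    """Binary focal loss: down-weights well-classified examples."""

    def __init__(self, alpha: float = 0.25, gamma: float = 2.0):
        super().__init__()
        self.alpha = alpha  # weight on the positive (rare) class
        self.gamma = gamma  # focusing parameter; gamma=0 reduces to plain BCE

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # Per-example binary cross-entropy on raw logits
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
        p_t = torch.exp(-bce)  # model's probability for the true class
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        return (alpha_t * (1 - p_t) ** self.gamma * bce).mean()

# Usage: criterion = FocalLoss(); loss = criterion(model_logits, labels.float())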

Handling Imbalanced Data

from imblearn.over_sampling import SMOTE
from imblearn.under_sampling import RandomUnderSampler
from imblearn.pipeline import Pipeline

class ImbalancedDataHandler:
    """Handle class imbalance in Reddit ML tasks."""

    def analyze_distribution(self, y: pd.Series) -> dict:
        """Analyze class distribution."""
        counts = y.value_counts()
        total = len(y)

        return {
            'counts': counts.to_dict(),
            'percentages': (counts / total * 100).to_dict(),
            'imbalance_ratio': counts.max() / counts.min()
        }

    def get_class_weights(self, y: np.ndarray) -> dict:
        """Calculate class weights for balanced learning."""
        from sklearn.utils.class_weight import compute_class_weight

        classes = np.unique(y)
        weights = compute_class_weight(
            class_weight='balanced',
            classes=classes,
            y=y
        )

        return dict(zip(classes, weights))

    def resample_data(
        self,
        X: np.ndarray,
        y: np.ndarray,
        strategy: str = 'combined'
    ) -> tuple:
        """Resample data to balance classes."""

        if strategy == 'oversample':
            sampler = SMOTE(random_state=42)
        elif strategy == 'undersample':
            sampler = RandomUnderSampler(random_state=42)
        elif strategy == 'combined':
            # Oversample minority, undersample majority
            sampler = Pipeline([
                ('over', SMOTE(sampling_strategy=0.5, random_state=42)),
                ('under', RandomUnderSampler(sampling_strategy=0.8, random_state=42))
            ])
        else:
            raise ValueError(f"Unknown strategy: {strategy}")

        X_resampled, y_resampled = sampler.fit_resample(X, y)

        print(f"Original: {len(y)} samples")
        print(f"Resampled: {len(y_resampled)} samples")

        return X_resampled, y_resampled

# Usage
handler = ImbalancedDataHandler()
print(handler.analyze_distribution(y))

# Resample if needed
X_balanced, y_balanced = handler.resample_data(X, y, strategy='combined')

Model Evaluation Framework

| Metric | Use When | Interpretation |
|---|---|---|
| Accuracy | Balanced classes | Overall correct predictions |
| Precision | False positives costly | How many positive predictions are correct |
| Recall | False negatives costly | How many actual positives are found |
| F1 Score | Imbalanced classes | Balance of precision and recall |
| ROC-AUC | Ranking quality matters | Model's discrimination ability |
| PR-AUC | Highly imbalanced data | Better than ROC-AUC for rare positives |

Pro Tip: Use Multiple Metrics

Don't rely on a single metric. For viral post detection, you might care more about recall (catching all viral posts) while for spam detection, precision (avoiding false positives) may be more important.
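
A small helper that reports the full metric suite from the table above in one pass makes this routine. A minimal sketch, assuming binary labels and an array of predicted probabilities (the names y_true and y_proba are placeholders):

import numpy as np
from sklearn.metrics import (accuracy_score, precision_score, recall_score,
                             f1_score, roc_auc_score, average_precision_score)

def evaluate_all(y_true: np.ndarray, y_proba: np.ndarray, threshold: float = 0.5) -> dict:
    """Compute several complementary metrics for a binary classifier."""
    y_pred = (y_proba >= threshold).astype(int)
    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred, zero_division=0),
        'recall': recall_score(y_true, y_pred, zero_division=0),
        'f1': f1_score(y_true, y_pred, zero_division=0),
        'roc_auc': roc_auc_score(y_true, y_proba),
        'pr_auc': average_precision_score(y_true, y_proba)  # average precision as a PR-AUC summary
    }

# Usage (hypothetical): print(evaluate_all(y_test, predictor.predict(X_test)))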

Skip the ML Pipeline

reddapi.dev provides pre-built AI classification and analysis. Get sentiment scores, topic categories, and relevance rankings without training your own models.


Production Deployment

import joblib
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List

app = FastAPI()

class PredictionRequest(BaseModel):
    title: str
    selftext: str = ""
    created_utc: float

class PredictionResponse(BaseModel):
    virality_score: float
    is_viral: bool
    confidence: float

# Load models at startup
@app.on_event("startup")
async def load_models():
    global predictor, extractor
    predictor = joblib.load("models/virality_predictor.joblib")
    extractor = RedditFeatureExtractor()

@app.post("/predict", response_model=PredictionResponse)
async def predict_virality(request: PredictionRequest):
    try:
        # Extract features
        post = {
            'title': request.title,
            'selftext': request.selftext,
            'created_utc': request.created_utc,
            'score': 0,  # Placeholder
            'num_comments': 0
        }

        features = extractor.extract_all_features(post)
        features_df = pd.DataFrame([features])

        # Align columns with the training features (drops the engagement
        # placeholders and fixes column order to match the fitted model)
        features_df = features_df.reindex(columns=predictor.feature_names, fill_value=0)

        # Predict
        score = predictor.predict(features_df)[0]

        return PredictionResponse(
            virality_score=float(score),
            is_viral=bool(score > 0.5),
            confidence=float(abs(score - 0.5) * 2)
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# Run with: uvicorn api:app --host 0.0.0.0 --port 8000

Frequently Asked Questions

What features matter most for virality prediction?

In practice, the most useful features include posting time (day and hour), title characteristics (length, question marks, emotional words), subreddit size and activity, author karma and posting history, and the presence of media (images, videos). Temporal features are often among the most predictive because they capture when active users are online.

How much training data do I need?

For traditional ML (XGBoost, Random Forest): 10k-50k samples work well. For deep learning (BERT): 50k-100k+ samples are recommended. For rare events (viral posts), you may need more data or synthetic augmentation. Quality and balance matter as much as quantity.

Should I use traditional ML or deep learning?

Start with traditional ML (XGBoost, LightGBM) for tabular features—they're fast, interpretable, and often perform well. Use deep learning (BERT, transformers) when text is your primary signal or when you have abundant data. Ensemble approaches that combine both often yield the best results; a minimal blending sketch follows.
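
As a rough illustration of the ensemble idea, a weighted soft vote over the two models' positive-class probabilities is often enough. A minimal sketch; the weights and the input arrays (xgb_scores, bert_scores) are placeholders you would tune and supply yourself:

import numpy as np

def blend_probabilities(tabular_proba: np.ndarray,
                        text_proba: np.ndarray,
                        w_tabular: float = 0.5) -> np.ndarray:
    """Weighted soft vote over two models' positive-class probabilities."""
    return w_tabular * tabular_proba + (1 - w_tabular) * text_proba

# Usage (hypothetical): blended = blend_probabilities(xgb_scores, bert_scores, w_tabular=0.6)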

How do I handle temporal data leakage?

For prediction tasks, ensure your training data doesn't include features that wouldn't be available at prediction time. For virality prediction, you can't use engagement metrics (upvotes, comments) that occur after posting. Use time-based train/test splits rather than random splits.
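
A minimal sketch of a time-based split, assuming a DataFrame with a created_utc column: train on the oldest posts and evaluate on the newest, so the model never sees the future.

import pandas as pd

def time_based_split(df: pd.DataFrame, test_fraction: float = 0.2) -> tuple:
    """Split chronologically: oldest posts for training, newest for testing."""
    df_sorted = df.sort_values('created_utc')
    cutoff = int(len(df_sorted) * (1 - test_fraction))
    return df_sorted.iloc[:cutoff], df_sorted.iloc[cutoff:]

# Usage: train_df, test_df = time_based_split(features_df)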

How often should I retrain models?

Reddit trends change rapidly. Monitor model performance weekly and retrain when metrics degrade (typically monthly). For production systems, implement online learning or scheduled retraining pipelines. Track concept drift by comparing prediction distributions over time.
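
One lightweight way to track drift is the Population Stability Index (PSI) between a baseline window of prediction scores and a recent one; values above roughly 0.2 are a common rule-of-thumb signal to investigate or retrain. A minimal sketch:

import numpy as np

def population_stability_index(expected: np.ndarray,
                               actual: np.ndarray,
                               bins: int = 10) -> float:
    """PSI between a baseline score distribution and a recent one."""
    edges = np.histogram_bin_edges(expected, bins=bins)
    expected_pct = np.histogram(expected, bins=edges)[0] / len(expected)
    actual_pct = np.histogram(actual, bins=edges)[0] / len(actual)
    # Floor empty bins at a tiny value to avoid log(0)
    expected_pct = np.clip(expected_pct, 1e-6, None)
    actual_pct = np.clip(actual_pct, 1e-6, None)
    return float(np.sum((actual_pct - expected_pct) * np.log(actual_pct / expected_pct)))

# Usage (hypothetical): psi = population_stability_index(last_week_scores, this_week_scores)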