Major changes:
- Remove PRAW dependency and Reddit API credentials
- Add no-OAuth Reddit scraper using public .json endpoints
- No Reddit API keys required - simpler setup!

New scraper features:
- Uses Reddit's public .json endpoints (www.reddit.com/r/subreddit.json)
- Configurable rate limiting via request_delay setting
- Automatic retry with backoff on rate limits and server errors
- Fetches posts and comments without authentication

Files changed:
- reddit/scraper.py (new) - No-OAuth Reddit scraper
- reddit/subreddit.py - Updated to use scraper instead of PRAW
- requirements.txt - Removed praw dependency
- utils/.config.template.toml - Removed Reddit credentials
- config.example.toml - Updated with scraper settings
- docker-entrypoint.sh - Updated for no-auth setup
- docker-compose.yml - Removed Reddit credential env vars
- main.py - Updated exception handling

Limitations:
- Subject to Reddit's rate limiting (configurable delay)
- ~1000 post cap per subreddit listing
- Some comments may be missing in large threads

https://claude.ai/code/session_01HLLH3WjpmRzvaoY6eYSFAD
pull/2456/head
parent 94d8e45cf7
commit cd9f9f5b40
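For quick reference, the new scraper can be exercised on its own roughly like this. This is a minimal sketch based on the code added in this commit (reddit/scraper.py below); the subreddit and the custom User-Agent string are only placeholders:

from reddit.scraper import RedditScraper

# A custom User-Agent and a generous request_delay reduce the chance of being rate limited.
scraper = RedditScraper(user_agent="python:my_bot:1.0", request_delay=2.0)

# Fetch the 25 hottest posts, then the top comments of the first one.
posts = scraper.get_subreddit_posts("AskReddit", sort="hot", limit=25)
post, comments = scraper.get_post_with_comments(posts[0].id, comment_sort="top")

print(post.title, post.score)
for comment in comments[:5]:
    print(f"[depth {comment.depth}] {comment.body[:80]}")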
reddit/scraper.py (new file)
@@ -0,0 +1,506 @@
"""
No-OAuth Reddit scraper using public .json endpoints.
No API keys required - uses Reddit's public JSON interface.

Note: This approach is subject to rate limiting and may be blocked by Reddit.
For production use, consider using the official Reddit API with OAuth.
"""
import json
import time
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Tuple

import requests

from utils.console import print_substep


# Default User-Agent - customize this to avoid rate limiting
DEFAULT_USER_AGENT = "python:reddit_video_bot:1.0 (no-oauth scraper)"

# Reddit base URLs
REDDIT_BASES = ["https://www.reddit.com", "https://old.reddit.com"]


class RedditScraperError(Exception):
    """Exception raised for Reddit scraper errors."""
    pass


@dataclass
class RedditPost:
    """Represents a Reddit post/submission."""
    id: str
    name: str  # t3_xxx
    title: str
    selftext: str
    author: str
    created_utc: float
    score: int
    upvote_ratio: float
    num_comments: int
    permalink: str
    url: str
    over_18: bool
    stickied: bool
    subreddit: str

    @classmethod
    def from_json(cls, data: Dict[str, Any]) -> "RedditPost":
        return cls(
            id=data.get("id", ""),
            name=data.get("name", ""),
            title=data.get("title", ""),
            selftext=data.get("selftext", ""),
            author=data.get("author", "[deleted]"),
            created_utc=float(data.get("created_utc", 0)),
            score=int(data.get("score", 0)),
            upvote_ratio=float(data.get("upvote_ratio", 0)),
            num_comments=int(data.get("num_comments", 0)),
            permalink=data.get("permalink", ""),
            url=data.get("url", ""),
            over_18=bool(data.get("over_18", False)),
            stickied=bool(data.get("stickied", False)),
            subreddit=data.get("subreddit", ""),
        )

@dataclass
class RedditComment:
    """Represents a Reddit comment."""
    id: str
    name: str  # t1_xxx
    body: str
    author: str
    created_utc: float
    score: int
    permalink: str
    parent_id: str
    link_id: str
    depth: int
    stickied: bool

    @classmethod
    def from_json(cls, data: Dict[str, Any], depth: int = 0) -> "RedditComment":
        return cls(
            id=data.get("id", ""),
            name=data.get("name", ""),
            body=data.get("body", ""),
            author=data.get("author", "[deleted]"),
            created_utc=float(data.get("created_utc", 0)),
            score=int(data.get("score", 0)),
            permalink=data.get("permalink", ""),
            parent_id=data.get("parent_id", ""),
            link_id=data.get("link_id", ""),
            depth=depth,
            stickied=bool(data.get("stickied", False)),
        )

class RedditScraper:
    """
    No-OAuth Reddit scraper using public .json endpoints.

    Example usage:
        scraper = RedditScraper()
        posts = scraper.get_subreddit_posts("AskReddit", limit=25, sort="hot")
        post, comments = scraper.get_post_with_comments(posts[0].id)
    """

    def __init__(
        self,
        user_agent: str = DEFAULT_USER_AGENT,
        base_url: str = REDDIT_BASES[0],
        request_delay: float = 2.0,
        timeout: float = 30.0,
        max_retries: int = 5,
    ):
        """
        Initialize the Reddit scraper.

        Args:
            user_agent: User-Agent string for requests
            base_url: Reddit base URL (www.reddit.com or old.reddit.com)
            request_delay: Delay between requests in seconds
            timeout: Request timeout in seconds
            max_retries: Maximum number of retries per request
        """
        self.user_agent = user_agent
        self.base_url = base_url.rstrip("/")
        self.request_delay = request_delay
        self.timeout = timeout
        self.max_retries = max_retries
        self.session = requests.Session()
        self._last_request_time = 0.0

    def _rate_limit(self) -> None:
        """Enforce rate limiting between requests."""
        elapsed = time.time() - self._last_request_time
        if elapsed < self.request_delay:
            time.sleep(self.request_delay - elapsed)
        self._last_request_time = time.time()

    def _fetch_json(self, url: str, params: Optional[Dict[str, Any]] = None) -> Any:
        """
        Fetch JSON from a Reddit endpoint with retries and rate limiting.

        Args:
            url: Full URL to fetch
            params: Query parameters

        Returns:
            Parsed JSON response

        Raises:
            RedditScraperError: If request fails after retries
        """
        headers = {
            "User-Agent": self.user_agent,
            "Accept": "application/json",
        }

        if params is None:
            params = {}
        params["raw_json"] = 1  # Get unescaped JSON

        last_error: Optional[Exception] = None

        for attempt in range(self.max_retries):
            self._rate_limit()

            try:
                response = self.session.get(
                    url,
                    params=params,
                    headers=headers,
                    timeout=self.timeout,
                )

                # Handle rate limiting
                if response.status_code == 429:
                    retry_after = int(response.headers.get("Retry-After", 60))
                    print_substep(f"Rate limited. Waiting {retry_after}s...", style="yellow")
                    time.sleep(max(self.request_delay, retry_after))
                    last_error = RedditScraperError("Rate limited (429)")
                    continue

                # Handle server errors
                if 500 <= response.status_code < 600:
                    wait_time = self.request_delay * (attempt + 1)
                    print_substep(f"Server error {response.status_code}. Retrying in {wait_time}s...", style="yellow")
                    time.sleep(wait_time)
                    last_error = RedditScraperError(f"Server error: {response.status_code}")
                    continue

                # Handle other errors
                if response.status_code != 200:
                    raise RedditScraperError(
                        f"HTTP {response.status_code}: {response.text[:200]}"
                    )

                return response.json()

            except requests.exceptions.RequestException as e:
                last_error = e
                wait_time = self.request_delay * (attempt + 1)
                if attempt < self.max_retries - 1:
                    print_substep(f"Request failed: {e}. Retrying in {wait_time}s...", style="yellow")
                    time.sleep(wait_time)
                    continue

        raise RedditScraperError(f"Failed after {self.max_retries} attempts: {last_error}")

    def get_subreddit_posts(
        self,
        subreddit: str,
        sort: str = "hot",
        limit: int = 25,
        time_filter: str = "all",
        after: Optional[str] = None,
    ) -> List[RedditPost]:
        """
        Get posts from a subreddit.

        Args:
            subreddit: Subreddit name (without r/ prefix)
            sort: Sort method (hot, new, top, rising, controversial)
            limit: Maximum number of posts to retrieve (max 100 per request)
            time_filter: Time filter for top/controversial (hour, day, week, month, year, all)
            after: Pagination cursor (fullname of last item)

        Returns:
            List of RedditPost objects
        """
        # Clean subreddit name
        subreddit = subreddit.strip()
        if subreddit.lower().startswith("r/"):
            subreddit = subreddit[2:]

        url = f"{self.base_url}/r/{subreddit}/{sort}.json"
        params: Dict[str, Any] = {"limit": min(limit, 100)}

        if sort in ("top", "controversial"):
            params["t"] = time_filter
        if after:
            params["after"] = after

        data = self._fetch_json(url, params)

        posts = []
        children = data.get("data", {}).get("children", [])

        for child in children:
            if child.get("kind") != "t3":
                continue
            post_data = child.get("data", {})
            if post_data:
                posts.append(RedditPost.from_json(post_data))

        return posts

    def get_post_by_id(self, post_id: str) -> Optional[RedditPost]:
        """
        Get a single post by ID.

        Args:
            post_id: Post ID (without t3_ prefix)

        Returns:
            RedditPost object or None if not found
        """
        # Remove t3_ prefix if present
        if post_id.startswith("t3_"):
            post_id = post_id[3:]

        url = f"{self.base_url}/comments/{post_id}.json"
        params = {"limit": 0}  # Don't fetch comments

        try:
            data = self._fetch_json(url, params)
        except RedditScraperError:
            return None

        if not isinstance(data, list) or len(data) < 1:
            return None

        post_listing = data[0]
        children = post_listing.get("data", {}).get("children", [])

        if not children:
            return None

        post_data = children[0].get("data", {})
        return RedditPost.from_json(post_data) if post_data else None

    def get_post_with_comments(
        self,
        post_id: str,
        comment_sort: str = "top",
        comment_limit: int = 500,
        comment_depth: int = 10,
        max_comments: int = 1000,
    ) -> Tuple[Optional[RedditPost], List[RedditComment]]:
        """
        Get a post with its comments.

        Args:
            post_id: Post ID (without t3_ prefix)
            comment_sort: Comment sort (top, new, controversial, best, old, qa)
            comment_limit: Number of comments per request (max ~500)
            comment_depth: Maximum depth of comment tree
            max_comments: Hard cap on total comments to return

        Returns:
            Tuple of (RedditPost, List[RedditComment])
        """
        # Remove t3_ prefix if present
        if post_id.startswith("t3_"):
            post_id = post_id[3:]

        url = f"{self.base_url}/comments/{post_id}.json"
        params = {
            "sort": comment_sort,
            "limit": min(comment_limit, 500),
            "depth": comment_depth,
        }

        data = self._fetch_json(url, params)

        if not isinstance(data, list) or len(data) < 2:
            raise RedditScraperError(f"Unexpected response format for post {post_id}")

        # Parse post
        post_listing = data[0]
        post_children = post_listing.get("data", {}).get("children", [])

        if not post_children:
            return None, []

        post_data = post_children[0].get("data", {})
        post = RedditPost.from_json(post_data) if post_data else None

        # Parse comments
        comment_listing = data[1]
        comment_children = comment_listing.get("data", {}).get("children", [])

        comments: List[RedditComment] = []
        self._flatten_comments(comment_children, depth=0, out=comments, max_comments=max_comments)

        return post, comments

    def _flatten_comments(
        self,
        children: List[Dict[str, Any]],
        depth: int,
        out: List[RedditComment],
        max_comments: int,
    ) -> None:
        """
        Recursively flatten comment tree into a list.

        Ignores "more" placeholders - some comments may be missing in large threads.
        """
        for child in children:
            if len(out) >= max_comments:
                return

            kind = child.get("kind")
            data = child.get("data", {})

            if kind == "t1":
                # This is a comment
                comment = RedditComment.from_json(data, depth=depth)
                out.append(comment)

                # Process replies
                replies = data.get("replies")
                if isinstance(replies, dict):
                    reply_children = replies.get("data", {}).get("children", [])
                    if reply_children:
                        self._flatten_comments(
                            reply_children,
                            depth=depth + 1,
                            out=out,
                            max_comments=max_comments,
                        )

            elif kind == "more":
                # "More comments" placeholder - skip (some comments will be missing)
                continue

    def search_subreddit(
        self,
        subreddit: str,
        query: str,
        sort: str = "relevance",
        time_filter: str = "all",
        limit: int = 25,
    ) -> List[RedditPost]:
        """
        Search posts in a subreddit.

        Args:
            subreddit: Subreddit name
            query: Search query
            sort: Sort method (relevance, hot, top, new, comments)
            time_filter: Time filter (hour, day, week, month, year, all)
            limit: Maximum results

        Returns:
            List of matching posts
        """
        subreddit = subreddit.strip()
        if subreddit.lower().startswith("r/"):
            subreddit = subreddit[2:]

        url = f"{self.base_url}/r/{subreddit}/search.json"
        params = {
            "q": query,
            "sort": sort,
            "t": time_filter,
            "limit": min(limit, 100),
            "restrict_sr": "on",  # Restrict to subreddit
        }

        data = self._fetch_json(url, params)

        posts = []
        children = data.get("data", {}).get("children", [])

        for child in children:
            if child.get("kind") != "t3":
                continue
            post_data = child.get("data", {})
            if post_data:
                posts.append(RedditPost.from_json(post_data))

        return posts

    def get_posts_newer_than(
        self,
        subreddit: str,
        days: int = 30,
        max_posts: int = 1000,
    ) -> List[RedditPost]:
        """
        Get posts from a subreddit newer than a specified number of days.

        Note: Reddit listings are capped at ~1000 posts. If the subreddit has
        more posts than this in the time window, older posts will be missed.

        Args:
            subreddit: Subreddit name
            days: Number of days to look back
            max_posts: Maximum posts to retrieve

        Returns:
            List of posts within the time window
        """
        cutoff = datetime.now(timezone.utc) - timedelta(days=days)
        cutoff_ts = cutoff.timestamp()

        all_posts: List[RedditPost] = []
        after: Optional[str] = None

        while len(all_posts) < max_posts:
            posts = self.get_subreddit_posts(
                subreddit=subreddit,
                sort="new",
                limit=100,
                after=after,
            )

            if not posts:
                break

            for post in posts:
                # Skip stickied posts (they can be old)
                if post.stickied:
                    continue

                if post.created_utc < cutoff_ts:
                    # Reached posts older than cutoff
                    return all_posts

                all_posts.append(post)

                if len(all_posts) >= max_posts:
                    return all_posts

            # Set pagination cursor
            after = posts[-1].name

        return all_posts


# Global scraper instance
_scraper: Optional[RedditScraper] = None


def get_scraper() -> RedditScraper:
    """Get or create the global Reddit scraper instance."""
    global _scraper
    if _scraper is None:
        _scraper = RedditScraper()
    return _scraper
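The ~1000-post cap mentioned in the commit message comes from Reddit's public listings themselves: each request returns at most 100 items, and pagination via the after cursor stops around 1000 posts per listing. The sketch below illustrates how the after parameter of get_subreddit_posts can be used to page through a listing by hand (this mirrors what get_posts_newer_than does internally; the subreddit name is a placeholder):

from reddit.scraper import get_scraper

scraper = get_scraper()  # shared instance, so the request_delay applies across all calls

# Page through r/AskReddit's "new" listing 100 posts at a time.
# Reddit stops returning results after roughly 1000 posts per listing.
all_posts, after = [], None
while True:
    page = scraper.get_subreddit_posts("AskReddit", sort="new", limit=100, after=after)
    if not page:
        break
    all_posts.extend(page)
    after = page[-1].name  # fullname (t3_xxx) of the last post is the pagination cursor

print(f"Fetched {len(all_posts)} posts")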
reddit/subreddit.py
@@ -1,160 +1,283 @@
+"""
+Reddit subreddit thread fetcher using no-OAuth scraper.
+No API keys required - uses Reddit's public JSON endpoints.
+"""
 import re
+from typing import Dict, List, Optional, Any, Tuple

-import praw
-from praw.models import MoreComments
-from prawcore.exceptions import ResponseException

+from reddit.scraper import get_scraper, RedditPost, RedditComment, RedditScraperError
 from utils import settings
 from utils.ai_methods import sort_by_similarity
 from utils.console import print_step, print_substep
 from utils.posttextparser import posttextparser
 from utils.subreddit import get_subreddit_undone
 from utils.videos import check_done
 from utils.voice import sanitize_text


-def get_subreddit_threads(POST_ID: str):
-    """
-    Returns a list of threads from the AskReddit subreddit.
+class SubmissionWrapper:
+    """Wrapper to make RedditPost compatible with existing utility functions."""
+
+    def __init__(self, post: RedditPost):
+        self.id = post.id
+        self.title = post.title
+        self.selftext = post.selftext
+        self.author = post.author
+        self.score = post.score
+        self.upvote_ratio = post.upvote_ratio
+        self.num_comments = post.num_comments
+        self.permalink = post.permalink
+        self.url = post.url
+        self.over_18 = post.over_18
+        self.stickied = post.stickied
+        self.subreddit_name = post.subreddit
+        self._post = post
+
+    def to_post(self) -> RedditPost:
+        return self._post
+
+
+def get_subreddit_threads(POST_ID: Optional[str] = None) -> Dict[str, Any]:
+    """
+    Fetches a Reddit thread and its comments using the no-OAuth scraper.
+    No API keys required.

-    print_substep("Logging into Reddit.")
+    Args:
+        POST_ID: Optional specific post ID to fetch

-    content = {}
-    if settings.config["reddit"]["creds"]["2fa"]:
-        print("\nEnter your two-factor authentication code from your authenticator app.\n")
-        code = input("> ")
-        print()
-        pw = settings.config["reddit"]["creds"]["password"]
-        passkey = f"{pw}:{code}"
-    else:
-        passkey = settings.config["reddit"]["creds"]["password"]
-    username = settings.config["reddit"]["creds"]["username"]
-    if str(username).casefold().startswith("u/"):
-        username = username[2:]
-    try:
-        reddit = praw.Reddit(
-            client_id=settings.config["reddit"]["creds"]["client_id"],
-            client_secret=settings.config["reddit"]["creds"]["client_secret"],
-            user_agent="Accessing Reddit threads",
-            username=username,
-            passkey=passkey,
-            check_for_async=False,
-        )
-    except ResponseException as e:
-        if e.response.status_code == 401:
-            print("Invalid credentials - please check them in config.toml")
-    except:
-        print("Something went wrong...")
+    Returns:
+        Dictionary containing thread data and comments
+    """
+    print_substep("Connecting to Reddit (no-auth mode)...")

-    # Ask user for subreddit input
-    print_step("Getting subreddit threads...")
+    scraper = get_scraper()
+    content: Dict[str, Any] = {}
     similarity_score = 0
-    if not settings.config["reddit"]["thread"][
-        "subreddit"
-    ]:  # note to user. you can have multiple subreddits via reddit.subreddit("redditdev+learnpython")
-        try:
-            subreddit = reddit.subreddit(
-                re.sub(r"r\/", "", input("What subreddit would you like to pull from? "))
-                # removes the r/ from the input
-            )
-        except ValueError:
-            subreddit = reddit.subreddit("askreddit")

+    # Get subreddit from config or user input
+    print_step("Getting subreddit threads...")
+
+    subreddit_name = settings.config["reddit"]["thread"].get("subreddit", "")
+
+    if not subreddit_name:
+        subreddit_name = input("What subreddit would you like to pull from? ")
+        subreddit_name = re.sub(r"^r/", "", subreddit_name.strip())
+        if not subreddit_name:
+            subreddit_name = "AskReddit"
+            print_substep("Subreddit not defined. Using AskReddit.")
-    else:
-        sub = settings.config["reddit"]["thread"]["subreddit"]
-        print_substep(f"Using subreddit: r/{sub} from TOML config")
-        subreddit_choice = sub
-        if str(subreddit_choice).casefold().startswith("r/"):  # removes the r/ from the input
-            subreddit_choice = subreddit_choice[2:]
-        subreddit = reddit.subreddit(subreddit_choice)

-    if POST_ID:  # would only be called if there are multiple queued posts
-        submission = reddit.submission(id=POST_ID)

-    elif (
-        settings.config["reddit"]["thread"]["post_id"]
-        and len(str(settings.config["reddit"]["thread"]["post_id"]).split("+")) == 1
-    ):
-        submission = reddit.submission(id=settings.config["reddit"]["thread"]["post_id"])
-    elif settings.config["ai"]["ai_similarity_enabled"]:  # ai sorting based on comparison
-        threads = subreddit.hot(limit=50)
-        keywords = settings.config["ai"]["ai_similarity_keywords"].split(",")
-        keywords = [keyword.strip() for keyword in keywords]
-        # Reformat the keywords for printing
-        keywords_print = ", ".join(keywords)
-        print(f"Sorting threads by similarity to the given keywords: {keywords_print}")
-        threads, similarity_scores = sort_by_similarity(threads, keywords)
-        submission, similarity_score = get_subreddit_undone(
-            threads, subreddit, similarity_scores=similarity_scores
-        )
-    else:
-        threads = subreddit.hot(limit=25)
-        submission = get_subreddit_undone(threads, subreddit)

+    # Clean the subreddit name
+    if str(subreddit_name).lower().startswith("r/"):
+        subreddit_name = subreddit_name[2:]
+    print_substep(f"Using subreddit: r/{subreddit_name} from config")
+
+    # Get the submission
+    submission: Optional[RedditPost] = None
+
+    try:
+        if POST_ID:
+            # Specific post ID provided (for queued posts)
+            submission = scraper.get_post_by_id(POST_ID)
+            if not submission:
+                raise RedditScraperError(f"Could not find post with ID: {POST_ID}")
+
+        elif settings.config["reddit"]["thread"].get("post_id"):
+            # Post ID from config (single post)
+            post_id = str(settings.config["reddit"]["thread"]["post_id"])
+            if "+" not in post_id:  # Single post, not multiple
+                submission = scraper.get_post_by_id(post_id)
+                if not submission:
+                    raise RedditScraperError(f"Could not find post with ID: {post_id}")
+
+        elif settings.config["ai"].get("ai_similarity_enabled"):
+            # AI sorting based on keyword similarity
+            print_substep("Fetching posts for AI similarity sorting...")
+            posts = scraper.get_subreddit_posts(subreddit_name, sort="hot", limit=50)
+
+            if not posts:
+                raise RedditScraperError(f"No posts found in r/{subreddit_name}")
+
+            keywords = settings.config["ai"].get("ai_similarity_keywords", "").split(",")
+            keywords = [keyword.strip() for keyword in keywords if keyword.strip()]
+
+            if keywords:
+                keywords_print = ", ".join(keywords)
+                print_substep(f"Sorting threads by similarity to: {keywords_print}")
+
+                # Convert posts to format expected by sort_by_similarity
+                wrappers = [SubmissionWrapper(post) for post in posts]
+                sorted_wrappers, similarity_scores = sort_by_similarity(wrappers, keywords)
+
+                submission, similarity_score = _get_undone_post(
+                    sorted_wrappers, subreddit_name, similarity_scores=similarity_scores
+                )
+            else:
+                wrappers = [SubmissionWrapper(post) for post in posts]
+                submission = _get_undone_post(wrappers, subreddit_name)
+
+        else:
+            # Default: get hot posts
+            posts = scraper.get_subreddit_posts(subreddit_name, sort="hot", limit=25)
+
+            if not posts:
+                raise RedditScraperError(f"No posts found in r/{subreddit_name}")
+
+            wrappers = [SubmissionWrapper(post) for post in posts]
+            submission = _get_undone_post(wrappers, subreddit_name)
+
+    except RedditScraperError as e:
+        print_substep(f"Error fetching Reddit data: {e}", style="bold red")
+        raise
+
     if submission is None:
-        return get_subreddit_threads(POST_ID)  # submission already done. rerun
+        print_substep("No suitable submission found. Retrying...", style="yellow")
+        return get_subreddit_threads(POST_ID)

-    elif not submission.num_comments and settings.config["settings"]["storymode"] == "false":
-        print_substep("No comments found. Skipping.")
+    # Check if story mode with no comments is okay
+    if not submission.num_comments and not settings.config["settings"].get("storymode"):
+        print_substep("No comments found. Skipping.", style="bold red")
         exit()

-    submission = check_done(submission)  # double-checking
+    # Double-check if this post was already done
+    wrapper = SubmissionWrapper(submission)
+    checked = check_done(wrapper)
+    if checked is None:
+        print_substep("Post already processed. Finding another...", style="yellow")
+        return get_subreddit_threads(POST_ID)

+    # Display post info
     upvotes = submission.score
     ratio = submission.upvote_ratio * 100
     num_comments = submission.num_comments
-    threadurl = f"https://new.reddit.com/{submission.permalink}"
+    thread_url = f"https://new.reddit.com{submission.permalink}"

-    print_substep(f"Video will be: {submission.title} :thumbsup:", style="bold green")
-    print_substep(f"Thread url is: {threadurl} :thumbsup:", style="bold green")
+    print_substep(f"Video will be: {submission.title}", style="bold green")
+    print_substep(f"Thread url is: {thread_url}", style="bold green")
     print_substep(f"Thread has {upvotes} upvotes", style="bold blue")
-    print_substep(f"Thread has a upvote ratio of {ratio}%", style="bold blue")
+    print_substep(f"Thread has a upvote ratio of {ratio:.0f}%", style="bold blue")
     print_substep(f"Thread has {num_comments} comments", style="bold blue")

     if similarity_score:
         print_substep(
             f"Thread has a similarity score up to {round(similarity_score * 100)}%",
             style="bold blue",
         )

-    content["thread_url"] = threadurl
+    # Build content dictionary
+    content["thread_url"] = thread_url
     content["thread_title"] = submission.title
     content["thread_id"] = submission.id
     content["is_nsfw"] = submission.over_18
+    content["subreddit"] = subreddit_name
     content["comments"] = []
-    if settings.config["settings"]["storymode"]:
-        if settings.config["settings"]["storymodemethod"] == 1:

+    if settings.config["settings"].get("storymode"):
+        # Story mode - use the post's selftext
+        if settings.config["settings"].get("storymodemethod") == 1:
             content["thread_post"] = posttextparser(submission.selftext)
         else:
             content["thread_post"] = submission.selftext
     else:
-        for top_level_comment in submission.comments:
-            if isinstance(top_level_comment, MoreComments):
-                continue
+        # Comment mode - fetch and process comments
+        print_substep("Fetching comments...", style="bold blue")

+        try:
+            _, comments = scraper.get_post_with_comments(
+                submission.id,
+                comment_sort="top",
+                comment_limit=500,
+                max_comments=1000,
+            )

+            # Filter and process comments
+            max_len = int(settings.config["reddit"]["thread"].get("max_comment_length", 500))
+            min_len = int(settings.config["reddit"]["thread"].get("min_comment_length", 1))

+            for comment in comments:
+                # Skip non-top-level comments (depth > 0)
+                if comment.depth > 0:
+                    continue

+                # Skip deleted/removed
+                if comment.body in ["[removed]", "[deleted]"]:
+                    continue

-            if top_level_comment.body in ["[removed]", "[deleted]"]:
-                continue  # # see https://github.com/JasonLovesDoggo/RedditVideoMakerBot/issues/78
-            if not top_level_comment.stickied:
-                sanitised = sanitize_text(top_level_comment.body)
-                if not sanitised or sanitised == " ":
+                # Skip stickied comments
+                if comment.stickied:
                     continue
-                if len(top_level_comment.body) <= int(
-                    settings.config["reddit"]["thread"]["max_comment_length"]
-                ):
-                    if len(top_level_comment.body) >= int(
-                        settings.config["reddit"]["thread"]["min_comment_length"]
-                    ):
-                        if (
-                            top_level_comment.author is not None
-                            and sanitize_text(top_level_comment.body) is not None
-                        ):  # if errors occur with this change to if not.
-                            content["comments"].append(
-                                {
-                                    "comment_body": top_level_comment.body,
-                                    "comment_url": top_level_comment.permalink,
-                                    "comment_id": top_level_comment.id,
-                                }
-                            )

print_substep("Received subreddit threads Successfully.", style="bold green")
|
||||
|
||||
# Sanitize and validate
|
||||
sanitized = sanitize_text(comment.body)
|
||||
if not sanitized or sanitized.strip() == "":
|
||||
continue
|
||||
|
||||
# Check length constraints
|
||||
if len(comment.body) > max_len:
|
||||
continue
|
||||
if len(comment.body) < min_len:
|
||||
continue
|
||||
|
||||
# Skip if author is deleted
|
||||
if comment.author in ["[deleted]", "[removed]"]:
|
||||
continue
|
||||
|
||||
content["comments"].append({
|
||||
"comment_body": comment.body,
|
||||
"comment_url": comment.permalink,
|
||||
"comment_id": comment.id,
|
||||
})
|
||||
|
||||
print_substep(f"Collected {len(content['comments'])} valid comments", style="bold green")
|
||||
|
||||
except RedditScraperError as e:
|
||||
print_substep(f"Error fetching comments: {e}", style="yellow")
|
||||
# Continue without comments if fetch fails
|
||||
|
||||
print_substep("Received subreddit threads successfully.", style="bold green")
|
||||
return content
|
||||
|
||||
|
||||
def _get_undone_post(
|
||||
wrappers: List[SubmissionWrapper],
|
||||
subreddit_name: str,
|
||||
similarity_scores: Optional[List[float]] = None,
|
||||
) -> Optional[RedditPost] | Tuple[Optional[RedditPost], float]:
|
||||
"""
|
||||
Find a submission that hasn't been processed yet.
|
||||
|
||||
Args:
|
||||
wrappers: List of SubmissionWrapper objects
|
||||
subreddit_name: Name of the subreddit
|
||||
similarity_scores: Optional similarity scores for each submission
|
||||
|
||||
Returns:
|
||||
First undone RedditPost, or tuple of (RedditPost, similarity_score) if scores provided
|
||||
"""
|
||||
allow_nsfw = settings.config["settings"].get("allow_nsfw", False)
|
||||
min_comments = int(settings.config["reddit"]["thread"].get("min_comments", 20))
|
||||
|
||||
for i, wrapper in enumerate(wrappers):
|
||||
# Skip NSFW if not allowed
|
||||
if wrapper.over_18 and not allow_nsfw:
|
||||
continue
|
||||
|
||||
# Skip stickied posts
|
||||
if wrapper.stickied:
|
||||
continue
|
||||
|
||||
# Check minimum comments (unless story mode)
|
||||
if not settings.config["settings"].get("storymode"):
|
||||
if wrapper.num_comments < min_comments:
|
||||
continue
|
||||
|
||||
# Check if already done
|
||||
if check_done(wrapper) is None:
|
||||
continue
|
||||
|
||||
post = wrapper.to_post()
|
||||
|
||||
if similarity_scores is not None and i < len(similarity_scores):
|
||||
return post, similarity_scores[i]
|
||||
|
||||
return post
|
||||
|
||||
return None
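For orientation, downstream code consumes the dictionary returned by get_subreddit_threads. The sketch below only uses keys that the code above sets unconditionally; it is an illustration of the returned shape, not part of this commit:

from reddit.subreddit import get_subreddit_threads

content = get_subreddit_threads()

print(content["thread_title"], content["thread_url"])
print("NSFW:", content["is_nsfw"])

# In comment mode each entry carries the body, permalink and id of one top-level comment;
# in story mode this list stays empty and "thread_post" holds the selftext instead.
for entry in content["comments"]:
    print(entry["comment_id"], entry["comment_body"][:60])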