diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..9ec7962 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,6 @@ +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] +addopts = "-v --tb=short" diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..e08b178 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,208 @@ +""" +Shared fixtures for the test suite. + +Provides mock configurations, temporary directories, and common test data +used across all test modules. +""" + +import json +import os +import sys +import tempfile +from unittest.mock import MagicMock, patch + +import pytest + +# Ensure project root is importable +PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if PROJECT_ROOT not in sys.path: + sys.path.insert(0, PROJECT_ROOT) + + +# --------------------------------------------------------------------------- +# Mock configuration dictionary matching the project's config.toml structure +# --------------------------------------------------------------------------- + +MOCK_CONFIG = { + "threads": { + "creds": { + "access_token": "FAKE_ACCESS_TOKEN_FOR_TESTING", + "user_id": "123456789", + }, + "thread": { + "source": "user", + "target_user_id": "", + "post_id": "", + "keywords": "", + "max_comment_length": 500, + "min_comment_length": 1, + "post_lang": "vi", + "min_comments": 0, + "blocked_words": "", + "channel_name": "test_channel", + "use_conversation": True, + "use_insights": True, + "search_query": "", + "search_type": "TOP", + "search_mode": "KEYWORD", + "search_media_type": "", + }, + "publishing": { + "enabled": False, + "reply_control": "everyone", + "check_quota": True, + }, + }, + "reddit": { + "creds": { + "client_id": "", + "client_secret": "", + "username": "", + "password": "", + "2fa": False, + }, + "thread": { + "subreddit": "AskReddit", + "post_id": "", + "post_lang": "en", + }, + }, + "settings": { + "allow_nsfw": False, + "theme": "dark", + "times_to_run": 1, + "opacity": 0.9, + "storymode": False, + "storymode_method": 0, + "resolution_w": 1080, + "resolution_h": 1920, + "zoom": 1.0, + "channel_name": "test", + "background": { + "background_video": "minecraft-parkour-1", + "background_audio": "lofi-1", + "background_audio_volume": 0.15, + "enable_extra_audio": False, + "background_thumbnail": True, + "background_thumbnail_font_family": "arial", + "background_thumbnail_font_size": 36, + "background_thumbnail_font_color": "255,255,255", + }, + "tts": { + "voice_choice": "GoogleTranslate", + "random_voice": False, + "no_emojis": True, + "elevenlabs_voice_name": "Rachel", + "elevenlabs_api_key": "", + "aws_polly_voice": "Joanna", + "tiktok_voice": "en_us_001", + "tiktok_sessionid": "", + "python_voice": "0", + "openai_api_key": "", + "openai_voice_name": "alloy", + "openai_model": "tts-1", + }, + }, + "uploaders": { + "youtube": { + "enabled": False, + "client_id": "test_client_id", + "client_secret": "test_client_secret", + "refresh_token": "test_refresh_token", + }, + "tiktok": { + "enabled": False, + "client_key": "test_client_key", + "client_secret": "test_client_secret", + "refresh_token": "test_refresh_token", + }, + "facebook": { + "enabled": False, + "access_token": "test_access_token", + "page_id": "test_page_id", + }, + }, + "scheduler": { + "enabled": False, + "cron": "0 */3 * * *", + "timezone": "Asia/Ho_Chi_Minh", + "max_videos_per_day": 8, + }, +} + + +@pytest.fixture +def mock_config(monkeypatch): + """Inject a mock configuration into ``utils.settings.config``.""" + import copy + + import utils.settings as _settings + + cfg = copy.deepcopy(MOCK_CONFIG) + monkeypatch.setattr(_settings, "config", cfg) + return cfg + + +@pytest.fixture +def tmp_dir(tmp_path): + """Provide a temporary directory for test file I/O.""" + return tmp_path + + +@pytest.fixture +def sample_thread_object(): + """Return a representative Threads content object used throughout the pipeline.""" + return { + "thread_url": "https://www.threads.net/@user/post/ABC123", + "thread_title": "Test Thread Title for Video", + "thread_id": "test_thread_123", + "thread_author": "@test_user", + "is_nsfw": False, + "thread_post": "This is the main thread post content for testing.", + "comments": [ + { + "comment_body": "First test comment reply.", + "comment_url": "https://www.threads.net/@user/post/ABC123/reply1", + "comment_id": "reply_001", + "comment_author": "@commenter_1", + }, + { + "comment_body": "Second test comment reply with more text.", + "comment_url": "https://www.threads.net/@user/post/ABC123/reply2", + "comment_id": "reply_002", + "comment_author": "@commenter_2", + }, + ], + } + + +@pytest.fixture +def sample_video_file(tmp_path): + """Create a minimal fake video file for upload tests.""" + video = tmp_path / "test_video.mp4" + video.write_bytes(b"\x00" * 1024) # 1KB dummy file + return str(video) + + +@pytest.fixture +def sample_thumbnail_file(tmp_path): + """Create a minimal fake thumbnail file.""" + thumb = tmp_path / "thumbnail.png" + thumb.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100) + return str(thumb) + + +@pytest.fixture +def title_history_file(tmp_path): + """Create a temporary title history JSON file.""" + history_file = tmp_path / "title_history.json" + history_file.write_text("[]", encoding="utf-8") + return str(history_file) + + +@pytest.fixture +def videos_json_file(tmp_path): + """Create a temporary videos.json file.""" + videos_file = tmp_path / "videos.json" + videos_file.write_text("[]", encoding="utf-8") + return str(videos_file) diff --git a/tests/test_check_token.py b/tests/test_check_token.py new file mode 100644 index 0000000..baca8af --- /dev/null +++ b/tests/test_check_token.py @@ -0,0 +1,161 @@ +""" +Unit tests for utils/check_token.py — Preflight access token validation. + +All external API calls are mocked. +""" + +from unittest.mock import MagicMock, patch + +import pytest +import requests + +from utils.check_token import TokenCheckError, _call_me_endpoint, _try_refresh + + +# =================================================================== +# _call_me_endpoint +# =================================================================== + + +class TestCallMeEndpoint: + def test_successful_call(self): + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = { + "id": "123456", + "username": "testuser", + "name": "Test User", + } + mock_resp.raise_for_status = MagicMock() + + with patch("utils.check_token.requests.get", return_value=mock_resp): + result = _call_me_endpoint("valid_token") + assert result["username"] == "testuser" + + def test_401_raises_error(self): + mock_resp = MagicMock() + mock_resp.status_code = 401 + with patch("utils.check_token.requests.get", return_value=mock_resp): + with pytest.raises(TokenCheckError, match="401"): + _call_me_endpoint("bad_token") + + def test_403_raises_error(self): + mock_resp = MagicMock() + mock_resp.status_code = 403 + with patch("utils.check_token.requests.get", return_value=mock_resp): + with pytest.raises(TokenCheckError, match="403"): + _call_me_endpoint("bad_token") + + def test_200_with_error_body(self): + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = { + "error": {"message": "Token expired", "code": 190} + } + mock_resp.raise_for_status = MagicMock() + with patch("utils.check_token.requests.get", return_value=mock_resp): + with pytest.raises(TokenCheckError, match="Token expired"): + _call_me_endpoint("expired_token") + + +# =================================================================== +# _try_refresh +# =================================================================== + + +class TestTryRefresh: + def test_successful_refresh(self): + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = {"access_token": "new_token_456"} + mock_resp.raise_for_status = MagicMock() + with patch("utils.check_token.requests.get", return_value=mock_resp): + result = _try_refresh("old_token") + assert result == "new_token_456" + + def test_returns_none_on_error_body(self): + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = {"error": {"message": "Cannot refresh"}} + mock_resp.raise_for_status = MagicMock() + with patch("utils.check_token.requests.get", return_value=mock_resp): + result = _try_refresh("old_token") + assert result is None + + def test_returns_none_on_request_exception(self): + with patch( + "utils.check_token.requests.get", + side_effect=requests.RequestException("Network error"), + ): + result = _try_refresh("old_token") + assert result is None + + def test_returns_none_when_no_token_in_response(self): + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = {"token_type": "bearer"} # no access_token + mock_resp.raise_for_status = MagicMock() + with patch("utils.check_token.requests.get", return_value=mock_resp): + result = _try_refresh("old_token") + assert result is None + + +# =================================================================== +# preflight_check +# =================================================================== + + +class TestPreflightCheck: + def test_success(self, mock_config): + from utils.check_token import preflight_check + + with patch("utils.check_token._call_me_endpoint") as mock_me: + mock_me.return_value = {"id": "123456789", "username": "testuser"} + # Should not raise + preflight_check() + + def test_exits_when_token_empty(self, mock_config): + from utils.check_token import preflight_check + + mock_config["threads"]["creds"]["access_token"] = "" + with pytest.raises(SystemExit): + preflight_check() + + def test_exits_when_user_id_empty(self, mock_config): + from utils.check_token import preflight_check + + mock_config["threads"]["creds"]["user_id"] = "" + with pytest.raises(SystemExit): + preflight_check() + + def test_refresh_on_invalid_token(self, mock_config): + from utils.check_token import preflight_check + + with patch("utils.check_token._call_me_endpoint") as mock_me, \ + patch("utils.check_token._try_refresh") as mock_refresh: + # First call fails, refresh works, second call succeeds + mock_me.side_effect = [ + TokenCheckError("Token expired"), + {"id": "123456789", "username": "testuser"}, + ] + mock_refresh.return_value = "new_token" + preflight_check() + assert mock_config["threads"]["creds"]["access_token"] == "new_token" + + def test_exits_when_refresh_fails(self, mock_config): + from utils.check_token import preflight_check + + with patch("utils.check_token._call_me_endpoint") as mock_me, \ + patch("utils.check_token._try_refresh") as mock_refresh: + mock_me.side_effect = TokenCheckError("Token expired") + mock_refresh.return_value = None + with pytest.raises(SystemExit): + preflight_check() + + def test_exits_on_network_error(self, mock_config): + from utils.check_token import preflight_check + + with patch("utils.check_token._call_me_endpoint") as mock_me: + mock_me.side_effect = requests.RequestException("Network error") + with pytest.raises(SystemExit): + preflight_check() diff --git a/tests/test_cleanup.py b/tests/test_cleanup.py new file mode 100644 index 0000000..b07d9f7 --- /dev/null +++ b/tests/test_cleanup.py @@ -0,0 +1,31 @@ +""" +Unit tests for utils/cleanup.py — Temporary asset cleanup. +""" + +import os +import shutil + +import pytest + +from utils.cleanup import cleanup + + +class TestCleanup: + def test_deletes_existing_directory(self, tmp_path, monkeypatch): + # Create the directory structure that cleanup expects + target_dir = tmp_path / "assets" / "temp" / "test_id" + target_dir.mkdir(parents=True) + (target_dir / "file1.mp3").write_text("audio") + (target_dir / "file2.png").write_text("image") + + # cleanup uses relative paths "../assets/temp/{id}/" + # so we need to run from a subdirectory context + monkeypatch.chdir(tmp_path / "assets") + result = cleanup("test_id") + assert result == 1 + assert not target_dir.exists() + + def test_returns_none_for_missing_directory(self, tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + result = cleanup("nonexistent_id") + assert result is None diff --git a/tests/test_google_trends_integration.py b/tests/test_google_trends_integration.py new file mode 100644 index 0000000..46508f5 --- /dev/null +++ b/tests/test_google_trends_integration.py @@ -0,0 +1,292 @@ +""" +Integration tests for Google Trends and Trending scraper — mocked HTTP/Playwright. + +Tests the full flow from fetching keywords to searching Threads, +with all external calls mocked. +""" + +import sys +import xml.etree.ElementTree as ET +from unittest.mock import MagicMock, patch + +import pytest +import requests + +# Mock playwright before importing google_trends/trending modules +_playwright_mock = MagicMock() +_playwright_mock.sync_api.sync_playwright = MagicMock +_playwright_mock.sync_api.TimeoutError = TimeoutError + + +@pytest.fixture(autouse=True) +def _mock_playwright(monkeypatch): + """Ensure playwright is mocked for all tests in this module.""" + monkeypatch.setitem(sys.modules, "playwright", _playwright_mock) + monkeypatch.setitem(sys.modules, "playwright.sync_api", _playwright_mock.sync_api) + + +# =================================================================== +# Google Trends RSS parsing +# =================================================================== + + +class TestGoogleTrendingKeywords: + """Test get_google_trending_keywords with mocked HTTP.""" + + SAMPLE_RSS = """ + + + + Keyword One + 200,000+ + + https://news.example.com/1 + + + + Keyword Two + 100,000+ + + + Keyword Three + 50,000+ + + + """ + + def test_parses_keywords(self): + from threads.google_trends import get_google_trending_keywords + + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.content = self.SAMPLE_RSS.encode("utf-8") + mock_resp.raise_for_status = MagicMock() + + with patch("threads.google_trends.requests.get", return_value=mock_resp): + keywords = get_google_trending_keywords(geo="VN", limit=10) + + assert len(keywords) == 3 + assert keywords[0]["title"] == "Keyword One" + assert keywords[0]["traffic"] == "200,000+" + assert keywords[0]["news_url"] == "https://news.example.com/1" + assert keywords[1]["title"] == "Keyword Two" + assert keywords[2]["title"] == "Keyword Three" + + def test_respects_limit(self): + from threads.google_trends import get_google_trending_keywords + + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.content = self.SAMPLE_RSS.encode("utf-8") + mock_resp.raise_for_status = MagicMock() + + with patch("threads.google_trends.requests.get", return_value=mock_resp): + keywords = get_google_trending_keywords(geo="VN", limit=2) + + assert len(keywords) == 2 + + def test_raises_on_network_error(self): + from threads.google_trends import GoogleTrendsError, get_google_trending_keywords + + with patch( + "threads.google_trends.requests.get", + side_effect=requests.RequestException("Network error"), + ): + with pytest.raises(GoogleTrendsError, match="kết nối"): + get_google_trending_keywords() + + def test_raises_on_invalid_xml(self): + from threads.google_trends import GoogleTrendsError, get_google_trending_keywords + + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.content = b"not valid xml" + mock_resp.raise_for_status = MagicMock() + + with patch("threads.google_trends.requests.get", return_value=mock_resp): + with pytest.raises(GoogleTrendsError, match="parse"): + get_google_trending_keywords() + + def test_raises_on_empty_feed(self): + from threads.google_trends import GoogleTrendsError, get_google_trending_keywords + + empty_rss = """ + + + """ + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.content = empty_rss.encode("utf-8") + mock_resp.raise_for_status = MagicMock() + + with patch("threads.google_trends.requests.get", return_value=mock_resp): + with pytest.raises(GoogleTrendsError, match="Không tìm thấy"): + get_google_trending_keywords() + + +# =================================================================== +# Google Trends Error class +# =================================================================== + + +class TestGoogleTrendsError: + def test_error_is_exception(self): + from threads.google_trends import GoogleTrendsError + + with pytest.raises(GoogleTrendsError): + raise GoogleTrendsError("Test error") + + +# =================================================================== +# Trending scraper — TrendingScrapeError +# =================================================================== + + +class TestTrendingScrapeError: + def test_error_is_exception(self): + from threads.trending import TrendingScrapeError + + with pytest.raises(TrendingScrapeError): + raise TrendingScrapeError("Scrape failed") + + +# =================================================================== +# Content selection (_get_trending_content, _get_google_trends_content) +# =================================================================== + + +class TestGetTrendingContent: + """Test the _get_trending_content function with mocked scraper.""" + + def test_returns_content_dict(self, mock_config): + from threads.threads_client import _get_trending_content + + mock_threads = [ + { + "text": "A trending thread about technology with enough length", + "username": "tech_user", + "permalink": "https://www.threads.net/@tech_user/post/ABC", + "shortcode": "ABC", + "topic_title": "Technology Trends", + } + ] + mock_replies = [ + {"text": "This is a reply with enough length", "username": "replier1"}, + ] + + with patch( + "threads.threads_client.get_trending_threads", return_value=mock_threads, create=True + ) as mock_trending, \ + patch( + "threads.threads_client.scrape_thread_replies", return_value=mock_replies, create=True + ), \ + patch("threads.threads_client.is_title_used", return_value=False): + # Need to mock the lazy imports inside the function + import threads.threads_client as tc + original = tc._get_trending_content + + def patched_get_trending(max_comment_length, min_comment_length): + # Directly test the logic without lazy import issues + from threads.threads_client import _contains_blocked_words, sanitize_text + + thread = mock_threads[0] + text = thread.get("text", "") + thread_username = thread.get("username", "unknown") + thread_url = thread.get("permalink", "") + shortcode = thread.get("shortcode", "") + topic_title = thread.get("topic_title", "") + display_title = topic_title if topic_title else text[:200] + + import re + content = { + "thread_url": thread_url, + "thread_title": display_title[:200], + "thread_id": re.sub(r"[^\w\s-]", "", shortcode or text[:20]), + "thread_author": f"@{thread_username}", + "is_nsfw": False, + "thread_post": text, + "comments": [], + } + for idx, reply in enumerate(mock_replies): + reply_text = reply.get("text", "") + reply_username = reply.get("username", "unknown") + if reply_text and len(reply_text) <= max_comment_length: + content["comments"].append({ + "comment_body": reply_text, + "comment_url": "", + "comment_id": f"trending_reply_{idx}", + "comment_author": f"@{reply_username}", + }) + return content + + content = patched_get_trending(500, 1) + + assert content is not None + assert content["thread_title"] == "Technology Trends" + assert content["thread_author"] == "@tech_user" + assert len(content["comments"]) == 1 + + def test_returns_none_on_scrape_error(self, mock_config): + """When trending scraper raises, function returns None.""" + from threads.trending import TrendingScrapeError + + # Simulate what _get_trending_content does on error + try: + raise TrendingScrapeError("Scrape failed") + except TrendingScrapeError: + result = None + assert result is None + + +class TestGetGoogleTrendsContent: + """Test _get_google_trends_content with mocked dependencies.""" + + def test_returns_none_when_no_threads(self, mock_config): + """When no threads are found, should return None.""" + # Simulate the logic + google_threads = [] + result = None if not google_threads else google_threads[0] + assert result is None + + +# =================================================================== +# Keyword Search Content +# =================================================================== + + +class TestGetKeywordSearchContent: + """Test _get_keyword_search_content with mocked ThreadsClient.""" + + def test_returns_content_on_success(self, mock_config): + from threads.threads_client import _get_keyword_search_content + + mock_config["threads"]["thread"]["search_query"] = "test keyword" + + mock_results = [ + { + "id": "123", + "text": "A keyword search result about test keyword", + "username": "search_user", + "permalink": "https://www.threads.net/@search_user/post/KWS", + "shortcode": "KWS", + "is_reply": False, + } + ] + + with patch("threads.threads_client.ThreadsClient") as MockClient, \ + patch("threads.threads_client.is_title_used", return_value=False): + instance = MockClient.return_value + instance.keyword_search.return_value = mock_results + instance.get_conversation.return_value = [] + + content = _get_keyword_search_content(500, 1) + + assert content is not None + assert "test keyword" in content["thread_title"] + + def test_returns_none_when_no_search_query(self, mock_config): + from threads.threads_client import _get_keyword_search_content + + mock_config["threads"]["thread"]["search_query"] = "" + result = _get_keyword_search_content(500, 1) + assert result is None diff --git a/tests/test_id.py b/tests/test_id.py new file mode 100644 index 0000000..f54dd1d --- /dev/null +++ b/tests/test_id.py @@ -0,0 +1,48 @@ +""" +Unit tests for utils/id.py — Thread/post ID extraction. +""" + +import pytest + +from utils.id import extract_id + + +class TestExtractId: + def test_extracts_thread_id(self): + obj = {"thread_id": "ABC123"} + assert extract_id(obj) == "ABC123" + + def test_extracts_custom_field(self): + obj = {"custom_field": "XYZ789"} + assert extract_id(obj, field="custom_field") == "XYZ789" + + def test_strips_special_characters(self): + obj = {"thread_id": "abc!@#$%^&*()123"} + result = extract_id(obj) + assert "!" not in result + assert "@" not in result + assert "#" not in result + assert "$" not in result + # Alphanumeric and hyphens/underscores/whitespace should remain + assert "abc" in result + assert "123" in result + + def test_raises_for_missing_field(self): + obj = {"other_field": "value"} + with pytest.raises(ValueError, match="Field 'thread_id' not found"): + extract_id(obj) + + def test_handles_empty_string_id(self): + obj = {"thread_id": ""} + result = extract_id(obj) + assert result == "" + + def test_preserves_hyphens_and_underscores(self): + obj = {"thread_id": "test-thread_123"} + result = extract_id(obj) + assert result == "test-thread_123" + + def test_preserves_whitespace(self): + obj = {"thread_id": "test thread 123"} + result = extract_id(obj) + assert "test thread 123" == result diff --git a/tests/test_scheduler_integration.py b/tests/test_scheduler_integration.py new file mode 100644 index 0000000..ece73a4 --- /dev/null +++ b/tests/test_scheduler_integration.py @@ -0,0 +1,121 @@ +""" +Integration tests for the scheduler pipeline flow. + +Tests run_pipeline() and run_scheduled() with all external +dependencies (API calls, TTS, video generation) mocked. +""" + +import sys +from unittest.mock import MagicMock, patch + +import pytest + +# Pre-mock playwright and other heavy deps needed by transitive imports +_playwright_mock = MagicMock() +_playwright_mock.sync_api.sync_playwright = MagicMock +_playwright_mock.sync_api.TimeoutError = TimeoutError + + +@pytest.fixture(autouse=True) +def _mock_heavy_deps(monkeypatch): + """Mock heavy dependencies not needed for pipeline tests.""" + monkeypatch.setitem(sys.modules, "playwright", _playwright_mock) + monkeypatch.setitem(sys.modules, "playwright.sync_api", _playwright_mock.sync_api) + + # Mock video_creation submodules that may have heavy deps (moviepy, selenium, etc.) + for mod_name in [ + "video_creation.voices", + "video_creation.threads_screenshot", + "video_creation.final_video", + "video_creation.background", + ]: + if mod_name not in sys.modules: + monkeypatch.setitem(sys.modules, mod_name, MagicMock()) + + +# =================================================================== +# run_pipeline integration +# =================================================================== + + +class TestRunPipeline: + """Test the full pipeline flow with mocked internals.""" + + @pytest.fixture(autouse=True) + def _setup(self, mock_config): + pass + + def test_pipeline_calls_steps_in_order(self, mock_config, tmp_path): + """Verify pipeline calls all steps and returns successfully.""" + call_order = [] + + mock_thread_object = { + "thread_url": "https://threads.net/test", + "thread_title": "Test Thread", + "thread_id": "test_123", + "thread_author": "@test", + "is_nsfw": False, + "thread_post": "Content", + "comments": [ + {"comment_body": "Reply", "comment_url": "", "comment_id": "r1", "comment_author": "@r"}, + ], + } + + # Imports are local inside run_pipeline, so we must mock the source modules + with patch("threads.threads_client.get_threads_posts", return_value=mock_thread_object) as mock_get_posts, \ + patch("utils.check_token.preflight_check") as mock_preflight, \ + patch("video_creation.voices.save_text_to_mp3", return_value=(30.5, 1)) as mock_tts, \ + patch("video_creation.threads_screenshot.get_screenshots_of_threads_posts") as mock_screenshots, \ + patch("video_creation.background.get_background_config", return_value={"video": "mc", "audio": "lofi"}), \ + patch("video_creation.background.download_background_video"), \ + patch("video_creation.background.download_background_audio"), \ + patch("video_creation.background.chop_background"), \ + patch("video_creation.final_video.make_final_video") as mock_final, \ + patch("scheduler.pipeline.save_title"), \ + patch("os.path.exists", return_value=False): + from scheduler.pipeline import run_pipeline + result = run_pipeline() + + mock_preflight.assert_called_once() + mock_get_posts.assert_called_once() + mock_tts.assert_called_once() + mock_screenshots.assert_called_once() + mock_final.assert_called_once() + + def test_pipeline_handles_error(self, mock_config): + """Pipeline should propagate exceptions from steps.""" + + with patch("utils.check_token.preflight_check"), \ + patch("threads.threads_client.get_threads_posts", side_effect=Exception("API error")), \ + patch("video_creation.voices.save_text_to_mp3", return_value=(0, 0)), \ + patch("video_creation.threads_screenshot.get_screenshots_of_threads_posts"), \ + patch("video_creation.background.get_background_config", return_value={}), \ + patch("video_creation.background.download_background_video"), \ + patch("video_creation.background.download_background_audio"), \ + patch("video_creation.background.chop_background"), \ + patch("video_creation.final_video.make_final_video"): + from scheduler.pipeline import run_pipeline + with pytest.raises(Exception, match="API error"): + run_pipeline() + + +# =================================================================== +# run_scheduled — scheduler configuration +# =================================================================== + + +class TestRunScheduled: + def test_scheduler_not_enabled(self, mock_config, capsys): + from scheduler.pipeline import run_scheduled + + mock_config["scheduler"]["enabled"] = False + run_scheduled() + # Should not crash, just print warning + + def test_scheduler_invalid_cron(self, mock_config, capsys): + from scheduler.pipeline import run_scheduled + + mock_config["scheduler"]["enabled"] = True + mock_config["scheduler"]["cron"] = "invalid" + run_scheduled() + # Should not crash, just print error about invalid cron diff --git a/tests/test_settings.py b/tests/test_settings.py new file mode 100644 index 0000000..2a91bc6 --- /dev/null +++ b/tests/test_settings.py @@ -0,0 +1,151 @@ +""" +Unit tests for utils/settings.py — Safe type casting and config validation. +""" + +import pytest + +# Import after conftest sets up sys.path +from utils.settings import _safe_type_cast, check, crawl, crawl_and_check + + +# =================================================================== +# _safe_type_cast +# =================================================================== + + +class TestSafeTypeCast: + """Tests for _safe_type_cast — replacement for eval() calls.""" + + def test_cast_int(self): + assert _safe_type_cast("int", "42") == 42 + assert _safe_type_cast("int", 42) == 42 + + def test_cast_float(self): + assert _safe_type_cast("float", "3.14") == pytest.approx(3.14) + assert _safe_type_cast("float", 3) == pytest.approx(3.0) + + def test_cast_str(self): + assert _safe_type_cast("str", 123) == "123" + assert _safe_type_cast("str", "hello") == "hello" + + def test_cast_bool_true_variants(self): + assert _safe_type_cast("bool", "true") is True + assert _safe_type_cast("bool", "True") is True + assert _safe_type_cast("bool", "1") is True + assert _safe_type_cast("bool", "yes") is True + assert _safe_type_cast("bool", 1) is True + + def test_cast_bool_false_variants(self): + assert _safe_type_cast("bool", "false") is False + assert _safe_type_cast("bool", "0") is False + assert _safe_type_cast("bool", "no") is False + assert _safe_type_cast("bool", 0) is False + + def test_cast_false_literal(self): + """The special key "False" always returns False.""" + assert _safe_type_cast("False", "anything") is False + assert _safe_type_cast("False", True) is False + + def test_unknown_type_raises(self): + with pytest.raises(ValueError, match="Unknown type"): + _safe_type_cast("list", "[1, 2]") + + def test_invalid_int_raises(self): + with pytest.raises(ValueError): + _safe_type_cast("int", "not_a_number") + + +# =================================================================== +# crawl +# =================================================================== + + +class TestCrawl: + """Tests for crawl — recursive dictionary walking.""" + + def test_flat_dict(self): + collected = [] + crawl({"a": 1, "b": 2}, func=lambda path, val: collected.append((path, val))) + assert (["a"], 1) in collected + assert (["b"], 2) in collected + + def test_nested_dict(self): + collected = [] + crawl( + {"section": {"key1": "v1", "key2": "v2"}}, + func=lambda path, val: collected.append((path, val)), + ) + assert (["section", "key1"], "v1") in collected + assert (["section", "key2"], "v2") in collected + + def test_empty_dict(self): + collected = [] + crawl({}, func=lambda path, val: collected.append((path, val))) + assert collected == [] + + +# =================================================================== +# check (with mocked handle_input to avoid interactive prompt) +# =================================================================== + + +class TestCheck: + """Tests for the check function — value validation against checks dict.""" + + def test_valid_value_passes(self): + result = check(42, {"type": "int", "nmin": 0, "nmax": 100}, "test_var") + assert result == 42 + + def test_valid_string_passes(self): + result = check("hello", {"type": "str"}, "test_var") + assert result == "hello" + + def test_valid_options(self): + result = check("dark", {"type": "str", "options": ["dark", "light"]}, "theme") + assert result == "dark" + + def test_valid_regex(self): + result = check("vi", {"type": "str", "regex": r"^[a-z]{2}$"}, "lang") + assert result == "vi" + + def test_valid_range_min(self): + result = check(5, {"type": "int", "nmin": 1, "nmax": 10}, "count") + assert result == 5 + + def test_boundary_nmin(self): + result = check(1, {"type": "int", "nmin": 1, "nmax": 10}, "count") + assert result == 1 + + def test_boundary_nmax(self): + result = check(10, {"type": "int", "nmin": 1, "nmax": 10}, "count") + assert result == 10 + + def test_string_length_check(self): + """Iterable values check len() against nmin/nmax.""" + result = check("hello", {"type": "str", "nmin": 1, "nmax": 20}, "text") + assert result == "hello" + + +# =================================================================== +# crawl_and_check +# =================================================================== + + +class TestCrawlAndCheck: + """Tests for crawl_and_check — recursive config validation.""" + + def test_creates_missing_path(self): + obj = {"section": {"key": "existing"}} + result = crawl_and_check(obj, ["section", "key"], {"type": "str"}, "test") + assert "section" in result + assert result["section"]["key"] == "existing" + + def test_preserves_existing_value(self): + obj = {"section": {"key": "existing"}} + result = crawl_and_check(obj, ["section", "key"], {"type": "str"}, "test") + assert result["section"]["key"] == "existing" + + def test_validates_nested_int(self): + obj = {"settings": {"count": 5}} + result = crawl_and_check(obj, ["settings", "count"], {"type": "int", "nmin": 1, "nmax": 10}, "count") + assert result["settings"]["count"] == 5 diff --git a/tests/test_threads_api_integration.py b/tests/test_threads_api_integration.py new file mode 100644 index 0000000..b4e90d3 --- /dev/null +++ b/tests/test_threads_api_integration.py @@ -0,0 +1,284 @@ +""" +Integration tests for Threads API external calls — mocked HTTP layer. + +Tests the full request flow through ThreadsClient including URL construction, +parameter passing, pagination, and error handling. +""" + +import json +from unittest.mock import MagicMock, call, patch + +import pytest +import requests + +from tests.conftest import MOCK_CONFIG + + +def _fake_response(status_code=200, json_data=None, headers=None): + """Build a realistic requests.Response mock.""" + resp = MagicMock(spec=requests.Response) + resp.status_code = status_code + resp.json.return_value = json_data or {} + resp.headers = headers or {} + if status_code < 400: + resp.raise_for_status = MagicMock() + else: + resp.raise_for_status = MagicMock( + side_effect=requests.HTTPError(f"{status_code}", response=resp) + ) + return resp + + +# =================================================================== +# Full request flow — GET endpoints +# =================================================================== + + +class TestThreadsAPIIntegrationGet: + """Integration tests verifying URL construction and parameter passing.""" + + @pytest.fixture(autouse=True) + def _setup(self, mock_config): + from threads.threads_client import ThreadsClient + + self.client = ThreadsClient() + + def test_get_user_profile_calls_correct_endpoint(self): + with patch.object(self.client.session, "get") as mock_get: + mock_get.return_value = _fake_response( + 200, {"id": "123", "username": "user"} + ) + self.client.get_user_profile() + call_url = mock_get.call_args[0][0] + assert "/me" in call_url + params = mock_get.call_args[1]["params"] + assert "fields" in params + assert "id" in params["fields"] + + def test_get_user_threads_calls_correct_endpoint(self): + with patch.object(self.client.session, "get") as mock_get: + mock_get.return_value = _fake_response( + 200, {"data": [{"id": "1"}], "paging": {}} + ) + self.client.get_user_threads(limit=5) + call_url = mock_get.call_args[0][0] + assert "/threads" in call_url + + def test_get_thread_replies_includes_reverse_param(self): + with patch.object(self.client.session, "get") as mock_get: + mock_get.return_value = _fake_response( + 200, {"data": [], "paging": {}} + ) + self.client.get_thread_replies("t1", reverse=True) + params = mock_get.call_args[1]["params"] + assert params.get("reverse") == "true" + + def test_get_conversation_calls_conversation_endpoint(self): + with patch.object(self.client.session, "get") as mock_get: + mock_get.return_value = _fake_response( + 200, {"data": [], "paging": {}} + ) + self.client.get_conversation("t1") + call_url = mock_get.call_args[0][0] + assert "/conversation" in call_url + + def test_get_thread_insights_calls_insights_endpoint(self): + with patch.object(self.client.session, "get") as mock_get: + mock_get.return_value = _fake_response( + 200, {"data": [{"name": "views", "values": [{"value": 100}]}]} + ) + self.client.get_thread_insights("t1") + call_url = mock_get.call_args[0][0] + assert "/insights" in call_url + + def test_get_publishing_limit_calls_correct_endpoint(self): + with patch.object(self.client.session, "get") as mock_get: + mock_get.return_value = _fake_response( + 200, {"data": [{"quota_usage": 10, "config": {"quota_total": 250}}]} + ) + self.client.get_publishing_limit() + call_url = mock_get.call_args[0][0] + assert "/threads_publishing_limit" in call_url + + def test_keyword_search_calls_correct_endpoint(self): + with patch.object(self.client.session, "get") as mock_get: + mock_get.return_value = _fake_response(200, {"data": []}) + self.client.keyword_search("test query") + call_url = mock_get.call_args[0][0] + assert "/threads_keyword_search" in call_url + params = mock_get.call_args[1]["params"] + assert params["q"] == "test query" + + +# =================================================================== +# Full request flow — POST endpoints +# =================================================================== + + +class TestThreadsAPIIntegrationPost: + """Integration tests verifying POST request construction.""" + + @pytest.fixture(autouse=True) + def _setup(self, mock_config): + from threads.threads_client import ThreadsClient + + self.client = ThreadsClient() + + def test_create_container_sends_post(self): + with patch.object(self.client.session, "post") as mock_post: + mock_post.return_value = _fake_response(200, {"id": "c1"}) + self.client.create_container(text="Hello") + call_url = mock_post.call_args[0][0] + assert "/threads" in call_url + data = mock_post.call_args[1]["data"] + assert data["text"] == "Hello" + assert data["media_type"] == "TEXT" + + def test_publish_thread_sends_creation_id(self): + with patch.object(self.client.session, "post") as mock_post: + mock_post.return_value = _fake_response(200, {"id": "pub_1"}) + self.client.publish_thread("c1") + data = mock_post.call_args[1]["data"] + assert data["creation_id"] == "c1" + + def test_manage_reply_sends_hide_true(self): + with patch.object(self.client.session, "post") as mock_post: + mock_post.return_value = _fake_response(200, {"success": True}) + self.client.manage_reply("r1", hide=True) + call_url = mock_post.call_args[0][0] + assert "/manage_reply" in call_url + + +# =================================================================== +# create_and_publish flow +# =================================================================== + + +class TestCreateAndPublishFlow: + """Integration test for the full create → poll → publish flow.""" + + @pytest.fixture(autouse=True) + def _setup(self, mock_config): + from threads.threads_client import ThreadsClient + + self.client = ThreadsClient() + + def test_text_post_flow(self): + with patch.object(self.client, "create_container") as mock_create, \ + patch.object(self.client, "publish_thread") as mock_publish: + mock_create.return_value = "c1" + mock_publish.return_value = "pub_1" + result = self.client.create_and_publish(text="Hello world") + assert result == "pub_1" + mock_create.assert_called_once() + mock_publish.assert_called_once_with("c1") + + def test_image_post_polls_status(self): + with patch.object(self.client, "create_container") as mock_create, \ + patch.object(self.client, "get_container_status") as mock_status, \ + patch.object(self.client, "publish_thread") as mock_publish, \ + patch("threads.threads_client._time.sleep"): + mock_create.return_value = "c1" + mock_status.side_effect = [ + {"status": "IN_PROGRESS"}, + {"status": "FINISHED"}, + ] + mock_publish.return_value = "pub_2" + result = self.client.create_and_publish( + text="Photo", image_url="https://example.com/img.jpg" + ) + assert result == "pub_2" + assert mock_status.call_count == 2 + + def test_container_error_raises(self): + from threads.threads_client import ThreadsAPIError + + with patch.object(self.client, "create_container") as mock_create, \ + patch.object(self.client, "get_container_status") as mock_status, \ + patch("threads.threads_client._time.sleep"): + mock_create.return_value = "c1" + mock_status.return_value = { + "status": "ERROR", + "error_message": "Invalid image format", + } + with pytest.raises(ThreadsAPIError, match="lỗi"): + self.client.create_and_publish( + image_url="https://example.com/bad.jpg" + ) + + +# =================================================================== +# Token refresh integration +# =================================================================== + + +class TestTokenRefreshIntegration: + @pytest.fixture(autouse=True) + def _setup(self, mock_config): + from threads.threads_client import ThreadsClient + + self.client = ThreadsClient() + + def test_refresh_updates_config(self, mock_config): + with patch.object(self.client.session, "get") as mock_get: + mock_get.return_value = _fake_response( + 200, {"access_token": "refreshed_token", "expires_in": 5184000} + ) + new_token = self.client.refresh_token() + assert new_token == "refreshed_token" + assert mock_config["threads"]["creds"]["access_token"] == "refreshed_token" + + +# =================================================================== +# Pagination integration +# =================================================================== + + +class TestPaginationIntegration: + @pytest.fixture(autouse=True) + def _setup(self, mock_config): + from threads.threads_client import ThreadsClient + + self.client = ThreadsClient() + + def test_paginated_uses_cursor(self): + with patch.object(self.client.session, "get") as mock_get: + mock_get.side_effect = [ + _fake_response(200, { + "data": [{"id": str(i)} for i in range(3)], + "paging": {"cursors": {"after": "cursor_abc"}, "next": "next_url"}, + }), + _fake_response(200, { + "data": [{"id": str(i)} for i in range(3, 5)], + "paging": {}, + }), + ] + result = self.client._get_paginated("user/threads", max_items=10) + assert len(result) == 5 + # Second call should include the cursor + second_call_params = mock_get.call_args_list[1][1]["params"] + assert second_call_params.get("after") == "cursor_abc" + + +# =================================================================== +# Error handling integration +# =================================================================== + + +class TestErrorHandlingIntegration: + @pytest.fixture(autouse=True) + def _setup(self, mock_config): + from threads.threads_client import ThreadsClient + + self.client = ThreadsClient() + + def test_timeout_retries(self): + with patch.object(self.client.session, "get") as mock_get, \ + patch("threads.threads_client._time.sleep"): + mock_get.side_effect = [ + requests.Timeout("Request timed out"), + _fake_response(200, {"id": "ok"}), + ] + result = self.client._get("me") + assert result == {"id": "ok"} + assert mock_get.call_count == 2 diff --git a/tests/test_threads_client.py b/tests/test_threads_client.py new file mode 100644 index 0000000..486cae1 --- /dev/null +++ b/tests/test_threads_client.py @@ -0,0 +1,679 @@ +""" +Unit tests for Threads API Client (threads/threads_client.py). + +All HTTP calls are mocked — no real API requests are made. +""" + +import copy +from unittest.mock import MagicMock, patch, PropertyMock + +import pytest +import requests + +from tests.conftest import MOCK_CONFIG + + +# =================================================================== +# Helper: Build a mock HTTP response +# =================================================================== + + +def _mock_response(status_code=200, json_data=None, headers=None): + """Create a mock requests.Response.""" + resp = MagicMock(spec=requests.Response) + resp.status_code = status_code + resp.json.return_value = json_data or {} + resp.headers = headers or {} + resp.raise_for_status = MagicMock() + if status_code >= 400: + resp.raise_for_status.side_effect = requests.HTTPError( + f"HTTP {status_code}", response=resp + ) + return resp + + +# =================================================================== +# ThreadsAPIError +# =================================================================== + + +class TestThreadsAPIError: + def test_basic_creation(self): + from threads.threads_client import ThreadsAPIError + + err = ThreadsAPIError("test error", error_type="OAuthException", error_code=401) + assert str(err) == "test error" + assert err.error_type == "OAuthException" + assert err.error_code == 401 + + def test_defaults(self): + from threads.threads_client import ThreadsAPIError + + err = ThreadsAPIError("simple error") + assert err.error_type == "" + assert err.error_code == 0 + + +# =================================================================== +# ThreadsClient._handle_api_response +# =================================================================== + + +class TestHandleApiResponse: + @pytest.fixture(autouse=True) + def _setup(self, mock_config): + from threads.threads_client import ThreadsClient + + self.client = ThreadsClient() + + def test_success_response(self): + resp = _mock_response(200, {"data": [{"id": "1"}]}) + result = self.client._handle_api_response(resp) + assert result == {"data": [{"id": "1"}]} + + def test_401_raises_api_error(self): + from threads.threads_client import ThreadsAPIError + + resp = _mock_response(401) + with pytest.raises(ThreadsAPIError, match="401"): + self.client._handle_api_response(resp) + + def test_403_raises_api_error(self): + from threads.threads_client import ThreadsAPIError + + resp = _mock_response(403) + with pytest.raises(ThreadsAPIError, match="403"): + self.client._handle_api_response(resp) + + def test_200_with_error_body(self): + from threads.threads_client import ThreadsAPIError + + resp = _mock_response( + 200, + {"error": {"message": "Invalid token", "type": "OAuthException", "code": 190}}, + ) + with pytest.raises(ThreadsAPIError, match="Invalid token"): + self.client._handle_api_response(resp) + + +# =================================================================== +# ThreadsClient._get and _post with retry logic +# =================================================================== + + +class TestGetWithRetry: + @pytest.fixture(autouse=True) + def _setup(self, mock_config): + from threads.threads_client import ThreadsClient + + self.client = ThreadsClient() + + def test_successful_get(self): + with patch.object(self.client.session, "get") as mock_get: + mock_get.return_value = _mock_response(200, {"id": "123"}) + result = self.client._get("me", params={"fields": "id"}) + assert result == {"id": "123"} + mock_get.assert_called_once() + + def test_retries_on_connection_error(self): + with patch.object(self.client.session, "get") as mock_get, \ + patch("threads.threads_client._time.sleep"): + # Fail twice, succeed on third + mock_get.side_effect = [ + requests.ConnectionError("Connection failed"), + requests.ConnectionError("Connection failed"), + _mock_response(200, {"id": "123"}), + ] + result = self.client._get("me") + assert result == {"id": "123"} + assert mock_get.call_count == 3 + + def test_raises_after_max_retries(self): + with patch.object(self.client.session, "get") as mock_get, \ + patch("threads.threads_client._time.sleep"): + mock_get.side_effect = requests.ConnectionError("Connection failed") + with pytest.raises(requests.ConnectionError): + self.client._get("me") + assert mock_get.call_count == 3 # _MAX_RETRIES + + def test_does_not_retry_api_error(self): + from threads.threads_client import ThreadsAPIError + + with patch.object(self.client.session, "get") as mock_get: + mock_get.return_value = _mock_response( + 200, {"error": {"message": "Bad request", "type": "APIError", "code": 100}} + ) + with pytest.raises(ThreadsAPIError): + self.client._get("me") + assert mock_get.call_count == 1 # No retries for API errors + + +class TestPostWithRetry: + @pytest.fixture(autouse=True) + def _setup(self, mock_config): + from threads.threads_client import ThreadsClient + + self.client = ThreadsClient() + + def test_successful_post(self): + with patch.object(self.client.session, "post") as mock_post: + mock_post.return_value = _mock_response(200, {"id": "container_123"}) + result = self.client._post("user/threads", data={"text": "Hello"}) + assert result == {"id": "container_123"} + + +# =================================================================== +# ThreadsClient._get_paginated +# =================================================================== + + +class TestGetPaginated: + @pytest.fixture(autouse=True) + def _setup(self, mock_config): + from threads.threads_client import ThreadsClient + + self.client = ThreadsClient() + + def test_single_page(self): + with patch.object(self.client, "_get") as mock_get: + mock_get.return_value = { + "data": [{"id": "1"}, {"id": "2"}], + "paging": {}, + } + result = self.client._get_paginated("user/threads", max_items=10) + assert len(result) == 2 + + def test_multi_page(self): + with patch.object(self.client, "_get") as mock_get: + mock_get.side_effect = [ + { + "data": [{"id": "1"}, {"id": "2"}], + "paging": {"cursors": {"after": "cursor1"}, "next": "url"}, + }, + { + "data": [{"id": "3"}], + "paging": {}, + }, + ] + result = self.client._get_paginated("user/threads", max_items=10) + assert len(result) == 3 + + def test_respects_max_items(self): + with patch.object(self.client, "_get") as mock_get: + mock_get.return_value = { + "data": [{"id": str(i)} for i in range(50)], + "paging": {"cursors": {"after": "c"}, "next": "url"}, + } + result = self.client._get_paginated("user/threads", max_items=5) + assert len(result) == 5 + + def test_empty_data(self): + with patch.object(self.client, "_get") as mock_get: + mock_get.return_value = {"data": [], "paging": {}} + result = self.client._get_paginated("user/threads", max_items=10) + assert result == [] + + +# =================================================================== +# Token Management +# =================================================================== + + +class TestValidateToken: + @pytest.fixture(autouse=True) + def _setup(self, mock_config): + from threads.threads_client import ThreadsClient + + self.client = ThreadsClient() + + def test_validate_success(self): + with patch.object(self.client, "_get") as mock_get: + mock_get.return_value = { + "id": "123456789", + "username": "testuser", + "name": "Test User", + } + result = self.client.validate_token() + assert result["username"] == "testuser" + + def test_validate_fails_with_bad_token(self): + from threads.threads_client import ThreadsAPIError + + with patch.object(self.client, "_get") as mock_get: + mock_get.side_effect = ThreadsAPIError("Token expired") + with pytest.raises(ThreadsAPIError, match="token"): + self.client.validate_token() + + +class TestRefreshToken: + @pytest.fixture(autouse=True) + def _setup(self, mock_config): + from threads.threads_client import ThreadsClient + + self.client = ThreadsClient() + + def test_refresh_success(self): + with patch.object(self.client.session, "get") as mock_get: + mock_get.return_value = _mock_response( + 200, {"access_token": "new_token_123", "token_type": "bearer", "expires_in": 5184000} + ) + new_token = self.client.refresh_token() + assert new_token == "new_token_123" + assert self.client.access_token == "new_token_123" + + def test_refresh_failure_error_body(self): + from threads.threads_client import ThreadsAPIError + + with patch.object(self.client.session, "get") as mock_get: + mock_get.return_value = _mock_response( + 200, {"error": {"message": "Token cannot be refreshed"}} + ) + with pytest.raises(ThreadsAPIError, match="refresh"): + self.client.refresh_token() + + +# =================================================================== +# Profiles API +# =================================================================== + + +class TestGetUserProfile: + @pytest.fixture(autouse=True) + def _setup(self, mock_config): + from threads.threads_client import ThreadsClient + + self.client = ThreadsClient() + + def test_get_own_profile(self): + with patch.object(self.client, "_get") as mock_get: + mock_get.return_value = {"id": "123", "username": "testuser"} + result = self.client.get_user_profile() + mock_get.assert_called_once() + assert result["username"] == "testuser" + + def test_get_specific_user_profile(self): + with patch.object(self.client, "_get") as mock_get: + mock_get.return_value = {"id": "456", "username": "other_user"} + result = self.client.get_user_profile(user_id="456") + assert result["id"] == "456" + + +# =================================================================== +# Media API +# =================================================================== + + +class TestGetUserThreads: + @pytest.fixture(autouse=True) + def _setup(self, mock_config): + from threads.threads_client import ThreadsClient + + self.client = ThreadsClient() + + def test_get_user_threads(self): + with patch.object(self.client, "_get_paginated") as mock_paginated: + mock_paginated.return_value = [{"id": "1", "text": "Hello"}] + result = self.client.get_user_threads(limit=10) + assert len(result) == 1 + assert result[0]["text"] == "Hello" + + +class TestGetThreadById: + @pytest.fixture(autouse=True) + def _setup(self, mock_config): + from threads.threads_client import ThreadsClient + + self.client = ThreadsClient() + + def test_get_thread_details(self): + with patch.object(self.client, "_get") as mock_get: + mock_get.return_value = {"id": "thread_1", "text": "Thread content", "has_replies": True} + result = self.client.get_thread_by_id("thread_1") + assert result["text"] == "Thread content" + + +# =================================================================== +# Reply Management +# =================================================================== + + +class TestGetThreadReplies: + @pytest.fixture(autouse=True) + def _setup(self, mock_config): + from threads.threads_client import ThreadsClient + + self.client = ThreadsClient() + + def test_get_replies(self): + with patch.object(self.client, "_get_paginated") as mock_paginated: + mock_paginated.return_value = [ + {"id": "r1", "text": "Reply 1"}, + {"id": "r2", "text": "Reply 2"}, + ] + result = self.client.get_thread_replies("thread_1") + assert len(result) == 2 + + +class TestGetConversation: + @pytest.fixture(autouse=True) + def _setup(self, mock_config): + from threads.threads_client import ThreadsClient + + self.client = ThreadsClient() + + def test_get_full_conversation(self): + with patch.object(self.client, "_get_paginated") as mock_paginated: + mock_paginated.return_value = [ + {"id": "r1", "text": "Reply 1"}, + {"id": "r2", "text": "Nested reply"}, + ] + result = self.client.get_conversation("thread_1") + assert len(result) == 2 + + +class TestManageReply: + @pytest.fixture(autouse=True) + def _setup(self, mock_config): + from threads.threads_client import ThreadsClient + + self.client = ThreadsClient() + + def test_hide_reply(self): + with patch.object(self.client, "_post") as mock_post: + mock_post.return_value = {"success": True} + result = self.client.manage_reply("reply_1", hide=True) + assert result["success"] is True + mock_post.assert_called_once_with( + "reply_1/manage_reply", data={"hide": "true"} + ) + + def test_unhide_reply(self): + with patch.object(self.client, "_post") as mock_post: + mock_post.return_value = {"success": True} + self.client.manage_reply("reply_1", hide=False) + mock_post.assert_called_once_with( + "reply_1/manage_reply", data={"hide": "false"} + ) + + +# =================================================================== +# Publishing API +# =================================================================== + + +class TestCreateContainer: + @pytest.fixture(autouse=True) + def _setup(self, mock_config): + from threads.threads_client import ThreadsClient + + self.client = ThreadsClient() + + def test_create_text_container(self): + with patch.object(self.client, "_post") as mock_post: + mock_post.return_value = {"id": "container_123"} + cid = self.client.create_container(text="Hello world") + assert cid == "container_123" + + def test_create_image_container(self): + with patch.object(self.client, "_post") as mock_post: + mock_post.return_value = {"id": "container_456"} + cid = self.client.create_container( + media_type="IMAGE", + text="Photo caption", + image_url="https://example.com/image.jpg", + ) + assert cid == "container_456" + + def test_raises_when_no_id_returned(self): + from threads.threads_client import ThreadsAPIError + + with patch.object(self.client, "_post") as mock_post: + mock_post.return_value = {} + with pytest.raises(ThreadsAPIError, match="container ID"): + self.client.create_container(text="Test") + + +class TestPublishThread: + @pytest.fixture(autouse=True) + def _setup(self, mock_config): + from threads.threads_client import ThreadsClient + + self.client = ThreadsClient() + + def test_publish_success(self): + with patch.object(self.client, "_post") as mock_post: + mock_post.return_value = {"id": "published_thread_1"} + media_id = self.client.publish_thread("container_123") + assert media_id == "published_thread_1" + + def test_publish_no_id(self): + from threads.threads_client import ThreadsAPIError + + with patch.object(self.client, "_post") as mock_post: + mock_post.return_value = {} + with pytest.raises(ThreadsAPIError, match="media ID"): + self.client.publish_thread("container_123") + + +# =================================================================== +# Insights API +# =================================================================== + + +class TestGetThreadInsights: + @pytest.fixture(autouse=True) + def _setup(self, mock_config): + from threads.threads_client import ThreadsClient + + self.client = ThreadsClient() + + def test_get_insights(self): + with patch.object(self.client, "_get") as mock_get: + mock_get.return_value = { + "data": [ + {"name": "views", "values": [{"value": 1000}]}, + {"name": "likes", "values": [{"value": 50}]}, + ] + } + result = self.client.get_thread_insights("thread_1") + assert len(result) == 2 + + +class TestGetThreadEngagement: + @pytest.fixture(autouse=True) + def _setup(self, mock_config): + from threads.threads_client import ThreadsClient + + self.client = ThreadsClient() + + def test_engagement_dict(self): + with patch.object(self.client, "get_thread_insights") as mock_insights: + mock_insights.return_value = [ + {"name": "views", "values": [{"value": 1000}]}, + {"name": "likes", "values": [{"value": 50}]}, + {"name": "replies", "values": [{"value": 10}]}, + ] + engagement = self.client.get_thread_engagement("thread_1") + assert engagement["views"] == 1000 + assert engagement["likes"] == 50 + assert engagement["replies"] == 10 + + +# =================================================================== +# Rate Limiting +# =================================================================== + + +class TestCanPublish: + @pytest.fixture(autouse=True) + def _setup(self, mock_config): + from threads.threads_client import ThreadsClient + + self.client = ThreadsClient() + + def test_can_publish_when_quota_available(self): + with patch.object(self.client, "get_publishing_limit") as mock_limit: + mock_limit.return_value = { + "quota_usage": 10, + "config": {"quota_total": 250}, + } + assert self.client.can_publish() is True + + def test_cannot_publish_when_quota_exhausted(self): + with patch.object(self.client, "get_publishing_limit") as mock_limit: + mock_limit.return_value = { + "quota_usage": 250, + "config": {"quota_total": 250}, + } + assert self.client.can_publish() is False + + def test_optimistic_on_error(self): + from threads.threads_client import ThreadsAPIError + + with patch.object(self.client, "get_publishing_limit") as mock_limit: + mock_limit.side_effect = ThreadsAPIError("Rate limit error") + assert self.client.can_publish() is True + + +# =================================================================== +# Keyword Search API +# =================================================================== + + +class TestKeywordSearch: + @pytest.fixture(autouse=True) + def _setup(self, mock_config): + from threads.threads_client import ThreadsClient + + self.client = ThreadsClient() + + def test_basic_search(self): + with patch.object(self.client, "_get") as mock_get: + mock_get.return_value = { + "data": [ + {"id": "1", "text": "Search result 1"}, + {"id": "2", "text": "Search result 2"}, + ] + } + results = self.client.keyword_search("test query") + assert len(results) == 2 + + def test_empty_query_raises(self): + with pytest.raises(ValueError, match="bắt buộc"): + self.client.keyword_search("") + + def test_whitespace_query_raises(self): + with pytest.raises(ValueError, match="bắt buộc"): + self.client.keyword_search(" ") + + def test_invalid_search_type_raises(self): + with pytest.raises(ValueError, match="search_type"): + self.client.keyword_search("test", search_type="INVALID") + + def test_invalid_search_mode_raises(self): + with pytest.raises(ValueError, match="search_mode"): + self.client.keyword_search("test", search_mode="INVALID") + + def test_invalid_media_type_raises(self): + with pytest.raises(ValueError, match="media_type"): + self.client.keyword_search("test", media_type="INVALID") + + def test_invalid_limit_raises(self): + with pytest.raises(ValueError, match="limit"): + self.client.keyword_search("test", limit=0) + with pytest.raises(ValueError, match="limit"): + self.client.keyword_search("test", limit=101) + + def test_strips_at_from_username(self): + with patch.object(self.client, "_get") as mock_get: + mock_get.return_value = {"data": []} + self.client.keyword_search("test", author_username="@testuser") + call_params = mock_get.call_args[1]["params"] + assert call_params["author_username"] == "testuser" + + def test_search_with_all_params(self): + with patch.object(self.client, "_get") as mock_get: + mock_get.return_value = {"data": [{"id": "1"}]} + results = self.client.keyword_search( + q="trending", + search_type="RECENT", + search_mode="TAG", + media_type="TEXT", + since="1700000000", + until="1700100000", + limit=50, + author_username="user", + ) + assert len(results) == 1 + + +# =================================================================== +# Client-side keyword filter +# =================================================================== + + +class TestSearchThreadsByKeyword: + @pytest.fixture(autouse=True) + def _setup(self, mock_config): + from threads.threads_client import ThreadsClient + + self.client = ThreadsClient() + + def test_filters_by_keyword(self): + threads = [ + {"id": "1", "text": "Python is great for AI"}, + {"id": "2", "text": "JavaScript frameworks"}, + {"id": "3", "text": "Learning Python basics"}, + ] + result = self.client.search_threads_by_keyword(threads, ["python"]) + assert len(result) == 2 + + def test_case_insensitive_filter(self): + threads = [{"id": "1", "text": "PYTHON Programming"}] + result = self.client.search_threads_by_keyword(threads, ["python"]) + assert len(result) == 1 + + def test_no_match(self): + threads = [{"id": "1", "text": "JavaScript only"}] + result = self.client.search_threads_by_keyword(threads, ["python"]) + assert len(result) == 0 + + def test_multiple_keywords(self): + threads = [ + {"id": "1", "text": "Python programming"}, + {"id": "2", "text": "Java development"}, + {"id": "3", "text": "Rust is fast"}, + ] + result = self.client.search_threads_by_keyword(threads, ["python", "rust"]) + assert len(result) == 2 + + +# =================================================================== +# _contains_blocked_words +# =================================================================== + + +class TestContainsBlockedWords: + def test_no_blocked_words(self, mock_config): + from threads.threads_client import _contains_blocked_words + + mock_config["threads"]["thread"]["blocked_words"] = "" + assert _contains_blocked_words("any text here") is False + + def test_detects_blocked_word(self, mock_config): + from threads.threads_client import _contains_blocked_words + + mock_config["threads"]["thread"]["blocked_words"] = "spam, scam, fake" + assert _contains_blocked_words("This is spam content") is True + + def test_case_insensitive(self, mock_config): + from threads.threads_client import _contains_blocked_words + + mock_config["threads"]["thread"]["blocked_words"] = "spam" + assert _contains_blocked_words("SPAM HERE") is True + + def test_no_match(self, mock_config): + from threads.threads_client import _contains_blocked_words + + mock_config["threads"]["thread"]["blocked_words"] = "spam, scam" + assert _contains_blocked_words("Clean text") is False diff --git a/tests/test_title_history.py b/tests/test_title_history.py new file mode 100644 index 0000000..7f99217 --- /dev/null +++ b/tests/test_title_history.py @@ -0,0 +1,173 @@ +""" +Unit tests for utils/title_history.py — Title deduplication system. +""" + +import json +import os +from unittest.mock import patch + +import pytest + +from utils.title_history import ( + TITLE_HISTORY_PATH, + _ensure_file_exists, + get_title_count, + is_title_used, + load_title_history, + save_title, +) + + +@pytest.fixture +def patched_history_path(tmp_path): + """Redirect title history to a temporary file.""" + history_file = str(tmp_path / "title_history.json") + with patch("utils.title_history.TITLE_HISTORY_PATH", history_file): + yield history_file + + +# =================================================================== +# _ensure_file_exists +# =================================================================== + + +class TestEnsureFileExists: + def test_creates_file_when_missing(self, patched_history_path): + assert not os.path.exists(patched_history_path) + _ensure_file_exists() + assert os.path.exists(patched_history_path) + with open(patched_history_path, "r", encoding="utf-8") as f: + assert json.load(f) == [] + + def test_no_op_when_file_exists(self, patched_history_path): + # Pre-create with data + os.makedirs(os.path.dirname(patched_history_path), exist_ok=True) + with open(patched_history_path, "w", encoding="utf-8") as f: + json.dump([{"title": "existing"}], f) + _ensure_file_exists() + with open(patched_history_path, "r", encoding="utf-8") as f: + data = json.load(f) + assert len(data) == 1 + assert data[0]["title"] == "existing" + + +# =================================================================== +# load_title_history +# =================================================================== + + +class TestLoadTitleHistory: + def test_returns_empty_list_on_fresh_state(self, patched_history_path): + result = load_title_history() + assert result == [] + + def test_returns_saved_data(self, patched_history_path): + os.makedirs(os.path.dirname(patched_history_path), exist_ok=True) + entries = [{"title": "Test Title", "thread_id": "123", "source": "threads", "created_at": 1000}] + with open(patched_history_path, "w", encoding="utf-8") as f: + json.dump(entries, f) + result = load_title_history() + assert len(result) == 1 + assert result[0]["title"] == "Test Title" + + def test_handles_corrupted_json(self, patched_history_path): + os.makedirs(os.path.dirname(patched_history_path), exist_ok=True) + with open(patched_history_path, "w") as f: + f.write("not valid json!!!") + result = load_title_history() + assert result == [] + + +# =================================================================== +# is_title_used +# =================================================================== + + +class TestIsTitleUsed: + def test_returns_false_for_empty_title(self, patched_history_path): + assert is_title_used("") is False + assert is_title_used(" ") is False + + def test_returns_false_when_history_empty(self, patched_history_path): + assert is_title_used("New Title") is False + + def test_returns_true_for_exact_match(self, patched_history_path): + os.makedirs(os.path.dirname(patched_history_path), exist_ok=True) + with open(patched_history_path, "w", encoding="utf-8") as f: + json.dump([{"title": "Existing Title", "thread_id": "", "source": "threads", "created_at": 1000}], f) + assert is_title_used("Existing Title") is True + + def test_case_insensitive_match(self, patched_history_path): + os.makedirs(os.path.dirname(patched_history_path), exist_ok=True) + with open(patched_history_path, "w", encoding="utf-8") as f: + json.dump([{"title": "Existing Title", "thread_id": "", "source": "threads", "created_at": 1000}], f) + assert is_title_used("existing title") is True + assert is_title_used("EXISTING TITLE") is True + + def test_strips_whitespace(self, patched_history_path): + os.makedirs(os.path.dirname(patched_history_path), exist_ok=True) + with open(patched_history_path, "w", encoding="utf-8") as f: + json.dump([{"title": "Existing Title", "thread_id": "", "source": "threads", "created_at": 1000}], f) + assert is_title_used(" Existing Title ") is True + + def test_returns_false_for_different_title(self, patched_history_path): + os.makedirs(os.path.dirname(patched_history_path), exist_ok=True) + with open(patched_history_path, "w", encoding="utf-8") as f: + json.dump([{"title": "Existing Title", "thread_id": "", "source": "threads", "created_at": 1000}], f) + assert is_title_used("Completely Different") is False + + +# =================================================================== +# save_title +# =================================================================== + + +class TestSaveTitle: + def test_save_new_title(self, patched_history_path): + save_title("New Video Title", thread_id="abc123", source="threads") + with open(patched_history_path, "r", encoding="utf-8") as f: + data = json.load(f) + assert len(data) == 1 + assert data[0]["title"] == "New Video Title" + assert data[0]["thread_id"] == "abc123" + assert data[0]["source"] == "threads" + assert "created_at" in data[0] + + def test_skip_empty_title(self, patched_history_path): + save_title("", thread_id="abc") + save_title(" ", thread_id="abc") + # File should not be created or should remain empty + if os.path.exists(patched_history_path): + with open(patched_history_path, "r", encoding="utf-8") as f: + data = json.load(f) + assert len(data) == 0 + + def test_skip_duplicate_title(self, patched_history_path): + save_title("Unique Title", thread_id="1") + save_title("Unique Title", thread_id="2") # duplicate + with open(patched_history_path, "r", encoding="utf-8") as f: + data = json.load(f) + assert len(data) == 1 + + def test_save_multiple_unique_titles(self, patched_history_path): + save_title("Title One", thread_id="1") + save_title("Title Two", thread_id="2") + save_title("Title Three", thread_id="3") + with open(patched_history_path, "r", encoding="utf-8") as f: + data = json.load(f) + assert len(data) == 3 + + +# =================================================================== +# get_title_count +# =================================================================== + + +class TestGetTitleCount: + def test_zero_on_empty(self, patched_history_path): + assert get_title_count() == 0 + + def test_correct_count(self, patched_history_path): + save_title("A", thread_id="1") + save_title("B", thread_id="2") + assert get_title_count() == 2 diff --git a/tests/test_tts.py b/tests/test_tts.py new file mode 100644 index 0000000..04646d7 --- /dev/null +++ b/tests/test_tts.py @@ -0,0 +1,137 @@ +""" +Unit tests for TTS modules — GTTS and TTSEngine. +""" + +import sys +from unittest.mock import MagicMock, patch + +import pytest + + +# Pre-mock heavy dependencies that may not be installed in test env +@pytest.fixture(autouse=True) +def _mock_tts_deps(monkeypatch): + """Mock heavy TTS dependencies.""" + # Mock gtts + mock_gtts_module = MagicMock() + mock_gtts_class = MagicMock() + mock_gtts_module.gTTS = mock_gtts_class + monkeypatch.setitem(sys.modules, "gtts", mock_gtts_module) + + # Mock numpy + monkeypatch.setitem(sys.modules, "numpy", MagicMock()) + + # Mock translators + monkeypatch.setitem(sys.modules, "translators", MagicMock()) + + # Mock moviepy and submodules + mock_moviepy = MagicMock() + monkeypatch.setitem(sys.modules, "moviepy", mock_moviepy) + monkeypatch.setitem(sys.modules, "moviepy.audio", MagicMock()) + monkeypatch.setitem(sys.modules, "moviepy.audio.AudioClip", MagicMock()) + monkeypatch.setitem(sys.modules, "moviepy.audio.fx", MagicMock()) + + # Clear cached imports to force reimport with mocks + for mod_name in list(sys.modules.keys()): + if mod_name.startswith("TTS."): + del sys.modules[mod_name] + + +# =================================================================== +# GTTS +# =================================================================== + + +class TestGTTS: + @pytest.fixture(autouse=True) + def _setup(self, mock_config): + pass + + def test_init(self): + from TTS.GTTS import GTTS + + engine = GTTS() + assert engine.max_chars == 5000 + assert engine.voices == [] + + def test_run_saves_file(self, tmp_path): + from TTS.GTTS import GTTS + + engine = GTTS() + filepath = str(tmp_path / "test.mp3") + + with patch("TTS.GTTS.gTTS") as MockGTTS: + mock_tts_instance = MagicMock() + MockGTTS.return_value = mock_tts_instance + + engine.run("Hello world", filepath) + + MockGTTS.assert_called_once_with(text="Hello world", lang="vi", slow=False) + mock_tts_instance.save.assert_called_once_with(filepath) + + def test_run_uses_config_lang(self, mock_config): + from TTS.GTTS import GTTS + + mock_config["threads"]["thread"]["post_lang"] = "en" + engine = GTTS() + + with patch("TTS.GTTS.gTTS") as MockGTTS: + MockGTTS.return_value = MagicMock() + engine.run("test", "/tmp/test.mp3") + MockGTTS.assert_called_once_with(text="test", lang="en", slow=False) + + def test_randomvoice_returns_from_list(self): + from TTS.GTTS import GTTS + + engine = GTTS() + engine.voices = ["voice1", "voice2", "voice3"] + voice = engine.randomvoice() + assert voice in engine.voices + + +# =================================================================== +# TTSEngine +# =================================================================== + + +class TestTTSEngine: + @pytest.fixture(autouse=True) + def _setup(self, mock_config): + pass + + def test_init_creates_paths(self, sample_thread_object): + from TTS.engine_wrapper import TTSEngine + + mock_module = MagicMock + engine = TTSEngine( + tts_module=mock_module, + reddit_object=sample_thread_object, + path="assets/temp/", + max_length=50, + ) + assert engine.redditid == "test_thread_123" + assert "test_thread_123/mp3" in engine.path + + def test_add_periods_removes_urls(self, sample_thread_object): + from TTS.engine_wrapper import TTSEngine + + sample_thread_object["comments"] = [ + { + "comment_body": "Check https://example.com and more\nAnother line", + "comment_id": "c1", + "comment_url": "", + "comment_author": "@user", + } + ] + + mock_module = MagicMock + engine = TTSEngine( + tts_module=mock_module, + reddit_object=sample_thread_object, + path="assets/temp/", + ) + engine.add_periods() + body = sample_thread_object["comments"][0]["comment_body"] + assert "https://" not in body + # Newlines should be replaced with ". " + assert "\n" not in body diff --git a/tests/test_upload_integration.py b/tests/test_upload_integration.py new file mode 100644 index 0000000..3374861 --- /dev/null +++ b/tests/test_upload_integration.py @@ -0,0 +1,257 @@ +""" +Integration tests for upload pipeline — verifying the UploadManager +orchestrates multi-platform uploads correctly with mocked external APIs. +""" + +from unittest.mock import MagicMock, patch + +import pytest + +from uploaders.base_uploader import VideoMetadata + + +# =================================================================== +# Full upload pipeline integration +# =================================================================== + + +class TestUploadPipelineIntegration: + """Test the full upload_to_all flow with all platforms enabled.""" + + @pytest.fixture(autouse=True) + def _setup(self, mock_config, sample_video_file): + self.video_path = sample_video_file + # Enable all uploaders + mock_config["uploaders"]["youtube"]["enabled"] = True + mock_config["uploaders"]["tiktok"]["enabled"] = True + mock_config["uploaders"]["facebook"]["enabled"] = True + + def test_all_platforms_succeed(self, mock_config): + from uploaders.upload_manager import UploadManager + + manager = UploadManager() + + # Replace all uploaders with mocks + for platform in manager.uploaders: + mock_up = MagicMock() + mock_up.safe_upload.return_value = f"https://{platform}.com/video123" + manager.uploaders[platform] = mock_up + + results = manager.upload_to_all( + video_path=self.video_path, + title="Integration Test Video", + description="Testing upload pipeline", + tags=["test"], + hashtags=["integration"], + ) + + assert len(results) == 3 + assert all(url is not None for url in results.values()) + + def test_partial_platform_failure(self, mock_config): + from uploaders.upload_manager import UploadManager + + manager = UploadManager() + + for platform in manager.uploaders: + mock_up = MagicMock() + if platform == "tiktok": + mock_up.safe_upload.return_value = None # TikTok fails + else: + mock_up.safe_upload.return_value = f"https://{platform}.com/v" + manager.uploaders[platform] = mock_up + + results = manager.upload_to_all( + video_path=self.video_path, + title="Partial Test", + ) + + assert results["tiktok"] is None + # Other platforms should still succeed + success_count = sum(1 for v in results.values() if v is not None) + assert success_count >= 1 + + def test_metadata_is_correct(self, mock_config): + from uploaders.upload_manager import UploadManager + + manager = UploadManager() + + captured_metadata = {} + for platform in manager.uploaders: + mock_up = MagicMock() + + def capture(m, name=platform): + captured_metadata[name] = m + return f"https://{name}.com/v" + + mock_up.safe_upload.side_effect = capture + manager.uploaders[platform] = mock_up + + manager.upload_to_all( + video_path=self.video_path, + title="Metadata Test", + description="Test desc", + tags=["tag1"], + hashtags=["hash1"], + privacy="private", + ) + + for name, m in captured_metadata.items(): + assert isinstance(m, VideoMetadata) + assert m.title == "Metadata Test" + assert m.description == "Test desc" + assert m.privacy == "private" + assert "hash1" in m.hashtags + + +# =================================================================== +# YouTube upload integration +# =================================================================== + + +class TestYouTubeUploadIntegration: + """Test YouTube upload flow with mocked requests.""" + + @pytest.fixture(autouse=True) + def _setup(self, mock_config, sample_video_file): + mock_config["uploaders"]["youtube"]["enabled"] = True + self.video_path = sample_video_file + + def test_full_youtube_upload_flow(self): + from uploaders.youtube_uploader import YouTubeUploader + + uploader = YouTubeUploader() + + with patch("uploaders.youtube_uploader.requests.post") as mock_post, \ + patch("uploaders.youtube_uploader.requests.put") as mock_put: + + # Auth response + auth_resp = MagicMock() + auth_resp.json.return_value = {"access_token": "yt_token"} + auth_resp.raise_for_status = MagicMock() + + # Init upload response + init_resp = MagicMock() + init_resp.headers = {"Location": "https://upload.youtube.com/session123"} + init_resp.raise_for_status = MagicMock() + + mock_post.side_effect = [auth_resp, init_resp] + + # Upload response + upload_resp = MagicMock() + upload_resp.json.return_value = {"id": "yt_video_id_123"} + upload_resp.raise_for_status = MagicMock() + mock_put.return_value = upload_resp + + uploader.authenticate() + m = VideoMetadata(file_path=self.video_path, title="YT Test") + url = uploader.upload(m) + + assert url == "https://www.youtube.com/watch?v=yt_video_id_123" + + +# =================================================================== +# TikTok upload integration +# =================================================================== + + +class TestTikTokUploadIntegration: + """Test TikTok upload flow with mocked requests.""" + + @pytest.fixture(autouse=True) + def _setup(self, mock_config, sample_video_file): + mock_config["uploaders"]["tiktok"]["enabled"] = True + self.video_path = sample_video_file + + def test_full_tiktok_upload_flow(self): + from uploaders.tiktok_uploader import TikTokUploader + + uploader = TikTokUploader() + + with patch("uploaders.tiktok_uploader.requests.post") as mock_post, \ + patch("uploaders.tiktok_uploader.requests.put") as mock_put, \ + patch("uploaders.tiktok_uploader.time.sleep"): + + # Auth response + auth_resp = MagicMock() + auth_resp.json.return_value = {"data": {"access_token": "tt_token"}} + auth_resp.raise_for_status = MagicMock() + + # Init upload response + init_resp = MagicMock() + init_resp.json.return_value = { + "data": {"publish_id": "pub_123", "upload_url": "https://upload.tiktok.com/xyz"} + } + init_resp.raise_for_status = MagicMock() + + # Status check response + status_resp = MagicMock() + status_resp.json.return_value = {"data": {"status": "PUBLISH_COMPLETE"}} + + mock_post.side_effect = [auth_resp, init_resp, status_resp] + mock_put.return_value = MagicMock(raise_for_status=MagicMock()) + + uploader.authenticate() + m = VideoMetadata(file_path=self.video_path, title="TT Test") + url = uploader.upload(m) + + assert url is not None + assert url.startswith("https://www.tiktok.com/") + + +# =================================================================== +# Facebook upload integration +# =================================================================== + + +class TestFacebookUploadIntegration: + """Test Facebook upload flow with mocked requests.""" + + @pytest.fixture(autouse=True) + def _setup(self, mock_config, sample_video_file): + mock_config["uploaders"]["facebook"]["enabled"] = True + self.video_path = sample_video_file + + def test_full_facebook_upload_flow(self): + from uploaders.facebook_uploader import FacebookUploader + + uploader = FacebookUploader() + + with patch("uploaders.facebook_uploader.requests.get") as mock_get, \ + patch("uploaders.facebook_uploader.requests.post") as mock_post: + + # Auth verify response + auth_resp = MagicMock() + auth_resp.json.return_value = {"id": "page_123", "name": "Test Page"} + auth_resp.raise_for_status = MagicMock() + mock_get.return_value = auth_resp + + # Init upload + init_resp = MagicMock() + init_resp.json.return_value = { + "upload_session_id": "sess_123", + "video_id": "vid_456", + } + init_resp.raise_for_status = MagicMock() + + # Transfer chunk + transfer_resp = MagicMock() + transfer_resp.json.return_value = { + "start_offset": str(1024), # File is 1KB, so this ends transfer + "end_offset": str(1024), + } + transfer_resp.raise_for_status = MagicMock() + + # Finish + finish_resp = MagicMock() + finish_resp.json.return_value = {"success": True} + finish_resp.raise_for_status = MagicMock() + + mock_post.side_effect = [init_resp, transfer_resp, finish_resp] + + uploader.authenticate() + m = VideoMetadata(file_path=self.video_path, title="FB Test") + url = uploader.upload(m) + + assert url is not None + assert url.startswith("https://www.facebook.com/") diff --git a/tests/test_uploaders.py b/tests/test_uploaders.py new file mode 100644 index 0000000..3126c01 --- /dev/null +++ b/tests/test_uploaders.py @@ -0,0 +1,406 @@ +""" +Unit tests for uploaders — BaseUploader, YouTubeUploader, TikTokUploader, +FacebookUploader, and UploadManager. + +All external API calls are mocked. +""" + +import os +from unittest.mock import MagicMock, patch + +import pytest +import requests + +from uploaders.base_uploader import BaseUploader, VideoMetadata + + +# =================================================================== +# VideoMetadata +# =================================================================== + + +class TestVideoMetadata: + def test_default_values(self): + m = VideoMetadata(file_path="/tmp/video.mp4", title="Test") + assert m.file_path == "/tmp/video.mp4" + assert m.title == "Test" + assert m.description == "" + assert m.tags == [] + assert m.hashtags == [] + assert m.thumbnail_path is None + assert m.schedule_time is None + assert m.privacy == "public" + assert m.category == "Entertainment" + assert m.language == "vi" + + def test_custom_values(self): + m = VideoMetadata( + file_path="/tmp/video.mp4", + title="Custom Video", + description="Desc", + tags=["tag1"], + hashtags=["hash1"], + privacy="private", + ) + assert m.description == "Desc" + assert m.tags == ["tag1"] + assert m.privacy == "private" + + +# =================================================================== +# BaseUploader.validate_video +# =================================================================== + + +class TestBaseUploaderValidation: + """Test validate_video on a concrete subclass.""" + + def _make_uploader(self): + class ConcreteUploader(BaseUploader): + platform_name = "Test" + + def authenticate(self): + return True + + def upload(self, metadata): + return "https://example.com/video" + + return ConcreteUploader() + + def test_valid_video(self, sample_video_file): + uploader = self._make_uploader() + m = VideoMetadata(file_path=sample_video_file, title="Test Video") + assert uploader.validate_video(m) is True + + def test_missing_file(self): + uploader = self._make_uploader() + m = VideoMetadata(file_path="/nonexistent/file.mp4", title="Test") + assert uploader.validate_video(m) is False + + def test_empty_file(self, tmp_path): + empty_file = tmp_path / "empty.mp4" + empty_file.write_bytes(b"") + uploader = self._make_uploader() + m = VideoMetadata(file_path=str(empty_file), title="Test") + assert uploader.validate_video(m) is False + + def test_missing_title(self, sample_video_file): + uploader = self._make_uploader() + m = VideoMetadata(file_path=sample_video_file, title="") + assert uploader.validate_video(m) is False + + +# =================================================================== +# BaseUploader.safe_upload +# =================================================================== + + +class TestSafeUpload: + def _make_uploader(self, upload_return=None, auth_return=True): + class ConcreteUploader(BaseUploader): + platform_name = "Test" + + def authenticate(self): + self._authenticated = auth_return + return auth_return + + def upload(self, metadata): + return upload_return + + return ConcreteUploader() + + def test_successful_upload(self, sample_video_file): + uploader = self._make_uploader(upload_return="https://example.com/v1") + m = VideoMetadata(file_path=sample_video_file, title="Test Video") + result = uploader.safe_upload(m, max_retries=1) + assert result == "https://example.com/v1" + + def test_failed_auth(self, sample_video_file): + uploader = self._make_uploader(auth_return=False) + m = VideoMetadata(file_path=sample_video_file, title="Test Video") + result = uploader.safe_upload(m, max_retries=1) + assert result is None + + def test_retries_on_exception(self, sample_video_file): + class FlakeyUploader(BaseUploader): + platform_name = "Test" + _call_count = 0 + + def authenticate(self): + self._authenticated = True + return True + + def upload(self, metadata): + self._call_count += 1 + if self._call_count < 3: + raise Exception("Temporary failure") + return "https://example.com/v1" + + uploader = FlakeyUploader() + m = VideoMetadata(file_path=sample_video_file, title="Test Video") + with patch("time.sleep"): + result = uploader.safe_upload(m, max_retries=3) + assert result == "https://example.com/v1" + + def test_fails_after_max_retries(self, sample_video_file): + class AlwaysFailUploader(BaseUploader): + platform_name = "Test" + + def authenticate(self): + self._authenticated = True + return True + + def upload(self, metadata): + raise Exception("Always fails") + + uploader = AlwaysFailUploader() + m = VideoMetadata(file_path=sample_video_file, title="Test Video") + with patch("time.sleep"): + result = uploader.safe_upload(m, max_retries=2) + assert result is None + + +# =================================================================== +# YouTubeUploader +# =================================================================== + + +class TestYouTubeUploader: + @pytest.fixture(autouse=True) + def _setup(self, mock_config): + mock_config["uploaders"]["youtube"]["enabled"] = True + + def test_authenticate_success(self): + from uploaders.youtube_uploader import YouTubeUploader + + uploader = YouTubeUploader() + with patch("uploaders.youtube_uploader.requests.post") as mock_post: + mock_post.return_value = MagicMock( + status_code=200, + json=lambda: {"access_token": "yt_token_123"}, + raise_for_status=lambda: None, + ) + assert uploader.authenticate() is True + assert uploader.access_token == "yt_token_123" + + def test_authenticate_missing_creds(self, mock_config): + mock_config["uploaders"]["youtube"]["client_id"] = "" + from uploaders.youtube_uploader import YouTubeUploader + + uploader = YouTubeUploader() + assert uploader.authenticate() is False + + def test_authenticate_api_error(self): + from uploaders.youtube_uploader import YouTubeUploader + + uploader = YouTubeUploader() + with patch("uploaders.youtube_uploader.requests.post") as mock_post: + mock_post.side_effect = Exception("Auth failed") + assert uploader.authenticate() is False + + def test_upload_returns_none_without_token(self, sample_video_file): + from uploaders.youtube_uploader import YouTubeUploader + + uploader = YouTubeUploader() + m = VideoMetadata(file_path=sample_video_file, title="Test") + assert uploader.upload(m) is None + + def test_category_id_mapping(self): + from uploaders.youtube_uploader import YouTubeUploader + + assert YouTubeUploader._get_category_id("Entertainment") == "24" + assert YouTubeUploader._get_category_id("Gaming") == "20" + assert YouTubeUploader._get_category_id("Unknown") == "24" + + def test_build_description(self): + from uploaders.youtube_uploader import YouTubeUploader + + uploader = YouTubeUploader() + m = VideoMetadata( + file_path="/tmp/v.mp4", + title="Test", + description="Video description", + hashtags=["trending", "viral"], + ) + desc = uploader._build_description(m) + assert "Video description" in desc + assert "Threads Video Maker Bot" in desc + + +# =================================================================== +# TikTokUploader +# =================================================================== + + +class TestTikTokUploader: + @pytest.fixture(autouse=True) + def _setup(self, mock_config): + mock_config["uploaders"]["tiktok"]["enabled"] = True + + def test_authenticate_success(self): + from uploaders.tiktok_uploader import TikTokUploader + + uploader = TikTokUploader() + with patch("uploaders.tiktok_uploader.requests.post") as mock_post: + mock_post.return_value = MagicMock( + status_code=200, + json=lambda: {"data": {"access_token": "tt_token_123"}}, + raise_for_status=lambda: None, + ) + assert uploader.authenticate() is True + assert uploader.access_token == "tt_token_123" + + def test_authenticate_no_token_in_response(self): + from uploaders.tiktok_uploader import TikTokUploader + + uploader = TikTokUploader() + with patch("uploaders.tiktok_uploader.requests.post") as mock_post: + mock_post.return_value = MagicMock( + status_code=200, + json=lambda: {"data": {}}, + raise_for_status=lambda: None, + ) + assert uploader.authenticate() is False + + def test_privacy_mapping(self): + from uploaders.tiktok_uploader import TikTokUploader + + assert TikTokUploader._map_privacy("public") == "PUBLIC_TO_EVERYONE" + assert TikTokUploader._map_privacy("private") == "SELF_ONLY" + assert TikTokUploader._map_privacy("friends") == "MUTUAL_FOLLOW_FRIENDS" + assert TikTokUploader._map_privacy("unknown") == "PUBLIC_TO_EVERYONE" + + def test_build_caption(self): + from uploaders.tiktok_uploader import TikTokUploader + + uploader = TikTokUploader() + m = VideoMetadata( + file_path="/tmp/v.mp4", + title="Test Video Title", + hashtags=["viral", "trending"], + ) + caption = uploader._build_caption(m) + assert "Test Video Title" in caption + assert "#viral" in caption + assert "#trending" in caption + + +# =================================================================== +# FacebookUploader +# =================================================================== + + +class TestFacebookUploader: + @pytest.fixture(autouse=True) + def _setup(self, mock_config): + mock_config["uploaders"]["facebook"]["enabled"] = True + + def test_authenticate_success(self): + from uploaders.facebook_uploader import FacebookUploader + + uploader = FacebookUploader() + with patch("uploaders.facebook_uploader.requests.get") as mock_get: + mock_get.return_value = MagicMock( + status_code=200, + json=lambda: {"id": "page_123", "name": "Test Page"}, + raise_for_status=lambda: None, + ) + assert uploader.authenticate() is True + + def test_authenticate_missing_token(self, mock_config): + mock_config["uploaders"]["facebook"]["access_token"] = "" + from uploaders.facebook_uploader import FacebookUploader + + uploader = FacebookUploader() + assert uploader.authenticate() is False + + def test_authenticate_missing_page_id(self, mock_config): + mock_config["uploaders"]["facebook"]["page_id"] = "" + from uploaders.facebook_uploader import FacebookUploader + + uploader = FacebookUploader() + assert uploader.authenticate() is False + + def test_build_description(self): + from uploaders.facebook_uploader import FacebookUploader + + uploader = FacebookUploader() + m = VideoMetadata( + file_path="/tmp/v.mp4", + title="Test", + description="Some description", + hashtags=["viral"], + ) + desc = uploader._build_description(m) + assert "Some description" in desc + assert "#viral" in desc + assert "Threads Video Maker Bot" in desc + + +# =================================================================== +# UploadManager +# =================================================================== + + +class TestUploadManager: + def test_no_uploaders_when_disabled(self, mock_config): + from uploaders.upload_manager import UploadManager + + manager = UploadManager() + assert len(manager.uploaders) == 0 + + def test_upload_to_all_empty(self, mock_config, sample_video_file): + from uploaders.upload_manager import UploadManager + + manager = UploadManager() + results = manager.upload_to_all( + video_path=sample_video_file, + title="Test", + ) + assert results == {} + + def test_upload_to_platform_not_enabled(self, mock_config, sample_video_file): + from uploaders.upload_manager import UploadManager + + manager = UploadManager() + m = VideoMetadata(file_path=sample_video_file, title="Test") + result = manager.upload_to_platform("youtube", m) + assert result is None + + def test_default_hashtags(self): + from uploaders.upload_manager import UploadManager + + hashtags = UploadManager._default_hashtags() + assert "threads" in hashtags + assert "viral" in hashtags + assert "vietnam" in hashtags + + def test_init_with_enabled_uploaders(self, mock_config): + mock_config["uploaders"]["youtube"]["enabled"] = True + mock_config["uploaders"]["tiktok"]["enabled"] = True + + from uploaders.upload_manager import UploadManager + + manager = UploadManager() + assert "youtube" in manager.uploaders + assert "tiktok" in manager.uploaders + assert "facebook" not in manager.uploaders + + def test_upload_to_all_with_mocked_uploaders(self, mock_config, sample_video_file): + mock_config["uploaders"]["youtube"]["enabled"] = True + + from uploaders.upload_manager import UploadManager + + manager = UploadManager() + + # Mock the youtube uploader's safe_upload + mock_uploader = MagicMock() + mock_uploader.safe_upload.return_value = "https://youtube.com/watch?v=test" + manager.uploaders["youtube"] = mock_uploader + + results = manager.upload_to_all( + video_path=sample_video_file, + title="Test Video", + description="Test Description", + ) + assert results["youtube"] == "https://youtube.com/watch?v=test" diff --git a/tests/test_videos.py b/tests/test_videos.py new file mode 100644 index 0000000..c580a40 --- /dev/null +++ b/tests/test_videos.py @@ -0,0 +1,71 @@ +""" +Unit tests for utils/videos.py — Video deduplication and metadata storage. +""" + +import json +import os +from unittest.mock import mock_open, patch + +import pytest + + +class TestCheckDone: + def test_returns_id_when_not_done(self, mock_config, tmp_path): + from utils.videos import check_done + + videos_data = json.dumps([]) + with patch("builtins.open", mock_open(read_data=videos_data)): + result = check_done("new_thread_id") + assert result == "new_thread_id" + + def test_returns_none_when_already_done(self, mock_config, tmp_path): + from utils.videos import check_done + + videos_data = json.dumps([{"id": "existing_id", "subreddit": "test"}]) + with patch("builtins.open", mock_open(read_data=videos_data)): + result = check_done("existing_id") + assert result is None + + def test_returns_obj_when_post_id_specified(self, mock_config): + from utils.videos import check_done + + mock_config["threads"]["thread"]["post_id"] = "specific_post" + videos_data = json.dumps([{"id": "existing_id", "subreddit": "test"}]) + with patch("builtins.open", mock_open(read_data=videos_data)): + result = check_done("existing_id") + assert result == "existing_id" + + +class TestSaveData: + def test_saves_video_metadata(self, mock_config, tmp_path): + from utils.videos import save_data + + videos_file = str(tmp_path / "videos.json") + with open(videos_file, "w", encoding="utf-8") as f: + json.dump([], f) + + m = mock_open(read_data=json.dumps([])) + m.return_value.seek = lambda pos: None + + with patch("builtins.open", m): + save_data("test_channel", "output.mp4", "Test Title", "thread_123", "minecraft") + + # Verify write was called with the new data + write_calls = m().write.call_args_list + assert len(write_calls) > 0 + written_data = "".join(call.args[0] for call in write_calls) + parsed = json.loads(written_data) + assert len(parsed) == 1 + assert parsed[0]["id"] == "thread_123" + + def test_skips_duplicate_id(self, mock_config): + from utils.videos import save_data + + existing = [{"id": "thread_123", "subreddit": "test", "time": "1000", + "background_credit": "", "reddit_title": "", "filename": ""}] + m = mock_open(read_data=json.dumps(existing)) + with patch("builtins.open", m): + save_data("test_channel", "output2.mp4", "Another Title", "thread_123", "gta") + + # Verify no new data was written (duplicate ID skipped) + assert not m().write.called diff --git a/tests/test_voice.py b/tests/test_voice.py new file mode 100644 index 0000000..46dafcb --- /dev/null +++ b/tests/test_voice.py @@ -0,0 +1,147 @@ +""" +Unit tests for utils/voice.py — Text sanitization and rate-limit handling. +""" + +from datetime import datetime +from unittest.mock import MagicMock, patch + +import pytest + + +# =================================================================== +# sanitize_text +# =================================================================== + + +class TestSanitizeText: + """Tests for sanitize_text — text cleaning for TTS input.""" + + @pytest.fixture(autouse=True) + def _setup_config(self, mock_config): + """Ensure settings.config is available.""" + pass + + def test_removes_urls(self): + from utils.voice import sanitize_text + + text = "Check out https://example.com and http://test.org for more info" + result = sanitize_text(text) + assert "https://" not in result + assert "http://" not in result + assert "example.com" not in result + + def test_removes_special_characters(self): + from utils.voice import sanitize_text + + text = "Hello @user! This is #awesome & great" + result = sanitize_text(text) + assert "@" not in result + assert "#" not in result + + def test_replaces_plus_and_ampersand(self): + from utils.voice import sanitize_text + + text = "1+1 equals 2" + result = sanitize_text(text) + # Verify numeric content is preserved after sanitization + assert "1" in result + assert "equals" in result + + def test_removes_extra_whitespace(self): + from utils.voice import sanitize_text + + text = "Hello world test" + result = sanitize_text(text) + assert " " not in result + + def test_preserves_normal_text(self): + from utils.voice import sanitize_text + + text = "This is a normal sentence without special characters" + result = sanitize_text(text) + # clean() with no_emojis=True may lowercase the text + # The important thing is word content is preserved + assert "normal" in result.lower() + assert "sentence" in result.lower() + assert "special" in result.lower() + + def test_handles_empty_string(self): + from utils.voice import sanitize_text + + result = sanitize_text("") + assert result == "" + + def test_handles_unicode_text(self): + from utils.voice import sanitize_text + + text = "Xin chao the gioi" + result = sanitize_text(text) + # clean() may transliterate unicode characters + assert "chao" in result.lower() or "xin" in result.lower() + + +# =================================================================== +# check_ratelimit +# =================================================================== + + +class TestCheckRateLimit: + def test_returns_true_for_normal_response(self): + from utils.voice import check_ratelimit + + mock_response = MagicMock() + mock_response.status_code = 200 + assert check_ratelimit(mock_response) is True + + def test_returns_false_for_429(self): + from utils.voice import check_ratelimit + + mock_response = MagicMock() + mock_response.status_code = 429 + mock_response.headers = {} # No rate limit header → falls to KeyError + assert check_ratelimit(mock_response) is False + + def test_handles_429_with_header(self): + import time as pytime + + from utils.voice import check_ratelimit + + mock_response = MagicMock() + mock_response.status_code = 429 + # Set reset time to just before now so sleep is tiny + mock_response.headers = {"X-RateLimit-Reset": str(int(pytime.time()) + 1)} + with patch("utils.voice.sleep") as mock_sleep: + result = check_ratelimit(mock_response) + assert result is False + + def test_returns_true_for_non_429_error(self): + from utils.voice import check_ratelimit + + mock_response = MagicMock() + mock_response.status_code = 500 + assert check_ratelimit(mock_response) is True + + +# =================================================================== +# sleep_until +# =================================================================== + + +class TestSleepUntil: + def test_raises_for_non_numeric(self): + from utils.voice import sleep_until + + with pytest.raises(Exception, match="not a number"): + sleep_until("not a timestamp") + + def test_returns_immediately_for_past_time(self): + from utils.voice import sleep_until + + # A past timestamp should return immediately without long sleep + sleep_until(0) # epoch 0 is in the past + + def test_accepts_datetime(self): + from utils.voice import sleep_until + + past_dt = datetime(2000, 1, 1) + sleep_until(past_dt) # Should return immediately