1251 lines
50 KiB
Python
1251 lines
50 KiB
Python
import base64
|
||
import json
|
||
import logging
|
||
import os
|
||
import threading
|
||
import tempfile
|
||
import time
|
||
import unittest
|
||
from unittest.mock import patch
|
||
|
||
import auto_pool_maintainer as apm
|
||
|
||
|
||
class DummyResponse:
|
||
def __init__(self, status_code: int, *, text: str = "", payload=None):
|
||
self.status_code = status_code
|
||
self.text = text
|
||
self._payload = payload if payload is not None else {}
|
||
self.headers = {}
|
||
self.url = "https://auth.openai.com/email-verification"
|
||
|
||
def json(self):
|
||
if isinstance(self._payload, Exception):
|
||
raise self._payload
|
||
return self._payload
|
||
|
||
|
||
def build_test_jwt(payload: dict) -> str:
|
||
header = base64.urlsafe_b64encode(json.dumps({"alg": "none", "typ": "JWT"}).encode("utf-8")).rstrip(b"=").decode("ascii")
|
||
body = base64.urlsafe_b64encode(json.dumps(payload).encode("utf-8")).rstrip(b"=").decode("ascii")
|
||
return f"{header}.{body}.signature"
|
||
|
||
|
||
class FlowHelperTests(unittest.TestCase):
|
||
def test_request_with_local_retry_writes_flow_trace_log_with_redaction(self):
|
||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||
trace_path = os.path.join(tmp_dir, "flow-trace.jsonl")
|
||
recorder = apm.FlowTraceRecorder(trace_path, reveal_sensitive=False, body_limit=512)
|
||
|
||
class FakeSession:
|
||
def __init__(self):
|
||
self.cookies = []
|
||
|
||
def post(self, url, **kwargs):
|
||
response = DummyResponse(429, text='{"error":"rate_limit"}')
|
||
response.url = url
|
||
response.headers = {
|
||
"content-type": "application/json",
|
||
"set-cookie": "session=secret-cookie",
|
||
}
|
||
return response
|
||
|
||
response, reason = apm.request_with_local_retry(
|
||
FakeSession(), # type: ignore[arg-type]
|
||
"post",
|
||
"https://auth.openai.com/api/accounts/authorize/continue",
|
||
retry_attempts=1,
|
||
error_prefix="authorize_continue_request",
|
||
flow_trace=recorder,
|
||
headers={
|
||
"Authorization": "Bearer super-secret-token",
|
||
"Cookie": "session=plain-cookie",
|
||
"x-test": "ok",
|
||
},
|
||
json={"password": "PlainPassword123", "username": "trace@example.com"},
|
||
timeout=30,
|
||
verify=False,
|
||
)
|
||
|
||
self.assertEqual(reason, "")
|
||
self.assertIsNotNone(response)
|
||
|
||
with open(trace_path, "r", encoding="utf-8") as trace_file:
|
||
events = [json.loads(line) for line in trace_file if line.strip()]
|
||
|
||
event_names = [event["event"] for event in events]
|
||
self.assertIn("http_attempt", event_names)
|
||
self.assertIn("http_response", event_names)
|
||
|
||
attempt_event = next(event for event in events if event["event"] == "http_attempt")
|
||
response_event = next(event for event in events if event["event"] == "http_response")
|
||
|
||
self.assertEqual(attempt_event["request"]["url"], "https://auth.openai.com/api/accounts/authorize/continue")
|
||
self.assertEqual(response_event["response"]["status_code"], 429)
|
||
serialized = json.dumps(events, ensure_ascii=False)
|
||
self.assertNotIn("PlainPassword123", serialized)
|
||
self.assertNotIn("super-secret-token", serialized)
|
||
self.assertNotIn("plain-cookie", serialized)
|
||
|
||
def test_build_chatgpt_session_token_result_uses_callback_code(self):
|
||
access_token = build_test_jwt(
|
||
{
|
||
"email": "jwt@example.com",
|
||
"exp": 1760000000,
|
||
"https://api.openai.com/auth": {"chatgpt_account_id": "acct_123"},
|
||
}
|
||
)
|
||
|
||
class FakeSession:
|
||
def __init__(self):
|
||
self.calls = []
|
||
|
||
def get(self, url, **kwargs):
|
||
self.calls.append((url, kwargs))
|
||
if url == "https://chatgpt.com/api/auth/callback/openai?code=oauth-code":
|
||
return DummyResponse(200)
|
||
if url == "https://chatgpt.com/api/auth/session":
|
||
return DummyResponse(200, payload={"accessToken": access_token, "user": {"email": "jwt@example.com"}})
|
||
raise AssertionError(f"unexpected url: {url}")
|
||
|
||
session = FakeSession()
|
||
|
||
result = apm.build_chatgpt_session_token_result(
|
||
session=session, # type: ignore[arg-type]
|
||
auth_code="oauth-code",
|
||
chatgpt_base="https://chatgpt.com",
|
||
)
|
||
|
||
self.assertIsNotNone(result)
|
||
self.assertEqual(result["access_token"], access_token)
|
||
self.assertEqual(result["email"], "jwt@example.com")
|
||
self.assertEqual(result["account_id"], "acct_123")
|
||
self.assertEqual(result["exp"], 1760000000)
|
||
self.assertEqual(
|
||
[call[0] for call in session.calls],
|
||
[
|
||
"https://chatgpt.com/api/auth/callback/openai?code=oauth-code",
|
||
"https://chatgpt.com/api/auth/session",
|
||
],
|
||
)
|
||
|
||
def test_build_chatgpt_session_token_result_preserves_callback_query_params(self):
|
||
access_token = build_test_jwt(
|
||
{
|
||
"email": "jwt@example.com",
|
||
"exp": 1760000000,
|
||
"https://api.openai.com/auth": {"chatgpt_account_id": "acct_123"},
|
||
}
|
||
)
|
||
|
||
class FakeSession:
|
||
def __init__(self):
|
||
self.calls = []
|
||
|
||
def get(self, url, **kwargs):
|
||
self.calls.append((url, kwargs))
|
||
if (
|
||
url
|
||
== "https://chatgpt.com/api/auth/callback/openai"
|
||
"?code=oauth-code&scope=openid+email+profile+offline_access&state=oauth-state"
|
||
):
|
||
return DummyResponse(200)
|
||
if url == "https://chatgpt.com/api/auth/session":
|
||
return DummyResponse(200, payload={"accessToken": access_token, "user": {"email": "jwt@example.com"}})
|
||
raise AssertionError(f"unexpected url: {url}")
|
||
|
||
session = FakeSession()
|
||
|
||
result = apm.build_chatgpt_session_token_result(
|
||
session=session, # type: ignore[arg-type]
|
||
auth_code="oauth-code",
|
||
callback_params={
|
||
"code": "oauth-code",
|
||
"scope": "openid email profile offline_access",
|
||
"state": "oauth-state",
|
||
},
|
||
chatgpt_base="https://chatgpt.com",
|
||
)
|
||
|
||
self.assertIsNotNone(result)
|
||
self.assertEqual(result["access_token"], access_token)
|
||
self.assertEqual(
|
||
[call[0] for call in session.calls],
|
||
[
|
||
"https://chatgpt.com/api/auth/callback/openai?code=oauth-code&scope=openid+email+profile+offline_access&state=oauth-state",
|
||
"https://chatgpt.com/api/auth/session",
|
||
],
|
||
)
|
||
|
||
def test_build_chatgpt_session_token_result_finds_nested_jwt(self):
|
||
access_token = build_test_jwt(
|
||
{
|
||
"email": "jwt@example.com",
|
||
"exp": 1760000000,
|
||
"https://api.openai.com/auth": {"chatgpt_account_id": "acct_123"},
|
||
}
|
||
)
|
||
|
||
class FakeSession:
|
||
def __init__(self):
|
||
self.calls = []
|
||
|
||
def get(self, url, **kwargs):
|
||
self.calls.append((url, kwargs))
|
||
if url == "https://chatgpt.com/api/auth/session":
|
||
return DummyResponse(
|
||
200,
|
||
payload={
|
||
"user": {"email": "jwt@example.com"},
|
||
"session": {"tokens": [{"kind": "bearer", "value": access_token}]},
|
||
},
|
||
)
|
||
raise AssertionError(f"unexpected url: {url}")
|
||
|
||
session = FakeSession()
|
||
|
||
result = apm.build_chatgpt_session_token_result(
|
||
session=session, # type: ignore[arg-type]
|
||
auth_code="",
|
||
chatgpt_base="https://chatgpt.com",
|
||
)
|
||
|
||
self.assertIsNotNone(result)
|
||
self.assertEqual(result["access_token"], access_token)
|
||
self.assertEqual(result["email"], "jwt@example.com")
|
||
|
||
def test_is_transient_flow_error(self):
|
||
self.assertTrue(apm.is_transient_flow_error("oauth_step_http_503"))
|
||
self.assertTrue(apm.is_transient_flow_error("authorize_exception:timed out"))
|
||
self.assertFalse(apm.is_transient_flow_error("email_otp_validate_http_400"))
|
||
|
||
def test_parse_otp_validate_order(self):
|
||
self.assertEqual(apm.parse_otp_validate_order("normal,sentinel"), ("normal", "sentinel"))
|
||
self.assertEqual(apm.parse_otp_validate_order("sentinel,normal"), ("sentinel", "normal"))
|
||
self.assertEqual(apm.parse_otp_validate_order("invalid"), ("normal", "sentinel"))
|
||
|
||
def test_requires_phone_verification(self):
|
||
payload = {
|
||
"page": {"type": "phone_verification"},
|
||
"continue_url": "/add-phone",
|
||
}
|
||
self.assertTrue(apm.requires_phone_verification(payload, ""))
|
||
self.assertFalse(apm.requires_phone_verification({"page": {"type": "email_otp_verification"}}, ""))
|
||
|
||
def test_resolve_loop_interval_seconds(self):
|
||
self.assertEqual(apm.resolve_loop_interval_seconds({}, None), 60.0)
|
||
self.assertEqual(apm.resolve_loop_interval_seconds({"maintainer": {"loop_interval_seconds": 12}}, None), 12.0)
|
||
self.assertEqual(apm.resolve_loop_interval_seconds({"maintainer": {"loop_interval_seconds": 1}}, None), 5.0)
|
||
self.assertEqual(apm.resolve_loop_interval_seconds({}, 8.5), 8.5)
|
||
|
||
def test_parse_loop_next_check_in_seconds_from_log_line(self):
|
||
line = "2026-03-27 21:33:42 | INFO | 循环模式休眠 60.0s 后再次检查号池"
|
||
with patch("api_server.time.time", return_value=apm.dt.datetime(2026, 3, 27, 21, 34, 0).timestamp()):
|
||
import api_server as aps
|
||
|
||
remain = aps.parse_loop_next_check_in_seconds([line])
|
||
self.assertEqual(remain, 42)
|
||
|
||
def test_api_server_run_state_read_write_and_clear(self):
|
||
import api_server as aps
|
||
|
||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||
fake_state = aps.Path(tmp_dir) / "run_state.json"
|
||
with patch.object(aps, "RUN_STATE_FILE", fake_state):
|
||
aps.save_run_state(12345, "loop")
|
||
state = aps.load_run_state()
|
||
self.assertEqual(state.get("pid"), 12345)
|
||
self.assertEqual(state.get("mode"), "loop")
|
||
aps.clear_run_state()
|
||
self.assertFalse(fake_state.exists())
|
||
|
||
def test_api_server_is_pid_running_current_process(self):
|
||
import api_server as aps
|
||
|
||
self.assertTrue(aps.is_pid_running(os.getpid()))
|
||
self.assertFalse(aps.is_pid_running(99999999))
|
||
|
||
def test_analyze_usage_status_marks_quota_and_threshold(self):
|
||
body = {
|
||
"rate_limit": {
|
||
"allowed": True,
|
||
"limit_reached": False,
|
||
"primary_window": {"used_percent": 85},
|
||
"secondary_window": {"used_percent": 99},
|
||
}
|
||
}
|
||
usage = apm.analyze_usage_status(status_code=200, body_obj=body, body_text="", used_percent_threshold=80)
|
||
self.assertEqual(usage["used_percent"], 99.0)
|
||
self.assertTrue(usage["over_threshold"])
|
||
self.assertTrue(usage["is_quota"])
|
||
self.assertFalse(usage["is_healthy"])
|
||
|
||
def test_analyze_usage_status_marks_healthy(self):
|
||
body = {
|
||
"rate_limit": {
|
||
"allowed": True,
|
||
"limit_reached": False,
|
||
"primary_window": {"used_percent": 35},
|
||
}
|
||
}
|
||
usage = apm.analyze_usage_status(status_code=200, body_obj=body, body_text="", used_percent_threshold=80)
|
||
self.assertEqual(usage["used_percent"], 35.0)
|
||
self.assertFalse(usage["over_threshold"])
|
||
self.assertFalse(usage["is_quota"])
|
||
self.assertTrue(usage["is_healthy"])
|
||
|
||
def test_decide_clean_action(self):
|
||
self.assertEqual(apm.decide_clean_action(status_code=401, disabled=False, is_quota=False, over_threshold=False), "delete")
|
||
self.assertEqual(apm.decide_clean_action(status_code=200, disabled=False, is_quota=True, over_threshold=False), "disable")
|
||
self.assertEqual(apm.decide_clean_action(status_code=200, disabled=True, is_quota=False, over_threshold=False), "enable")
|
||
self.assertEqual(apm.decide_clean_action(status_code=None, disabled=False, is_quota=False, over_threshold=False), "keep")
|
||
|
||
def test_get_candidates_count_excludes_disabled_items(self):
|
||
files = [
|
||
{"type": "codex", "disabled": False},
|
||
{"type": "codex", "disabled": True},
|
||
{"type": "codex", "disabled": "false"},
|
||
{"type": "codex", "status": "disabled"},
|
||
{"type": "claude", "disabled": False},
|
||
]
|
||
total, candidates = apm.get_candidates_count_from_files(files, "codex")
|
||
self.assertEqual(total, 5)
|
||
self.assertEqual(candidates, 2)
|
||
|
||
def test_select_probe_candidates_returns_all_when_sample_size_disabled(self):
|
||
candidates = [{"name": "a"}, {"name": "b"}, {"name": "c"}]
|
||
selected = apm.select_probe_candidates(candidates, sample_size=0, rng=apm.random.Random(1))
|
||
self.assertEqual([item["name"] for item in selected], ["a", "b", "c"])
|
||
|
||
def test_select_probe_candidates_returns_random_subset(self):
|
||
candidates = [{"name": "a"}, {"name": "b"}, {"name": "c"}, {"name": "d"}, {"name": "e"}]
|
||
selected = apm.select_probe_candidates(candidates, sample_size=2, rng=apm.random.Random(7))
|
||
self.assertEqual([item["name"] for item in selected], ["c", "b"])
|
||
|
||
def test_run_clean_401_passes_sample_size_to_async_cleanup(self):
|
||
conf = {
|
||
"clean": {
|
||
"base_url": "https://example.test",
|
||
"token": "pw",
|
||
"sample_size": 3,
|
||
}
|
||
}
|
||
captured = {}
|
||
|
||
async def fake_run_clean_401_async(**kwargs):
|
||
captured.update(kwargs)
|
||
return {"action_total": 0}
|
||
|
||
with patch.object(apm, "aiohttp", object()), patch.object(apm, "run_clean_401_async", fake_run_clean_401_async):
|
||
result = apm.run_clean_401(conf, logging.getLogger("test-clean-sample"))
|
||
|
||
self.assertEqual(captured["sample_size"], 3)
|
||
self.assertEqual(result["action_total"], 0)
|
||
|
||
def test_mail_provider_session_reuses_same_thread_and_isolates_cross_thread(self):
|
||
provider = apm.SelfHostedMailApiProvider(
|
||
proxy="",
|
||
logger=logging.getLogger("test-mail-session"),
|
||
api_base="https://example.test",
|
||
api_key="k",
|
||
domain="x.test",
|
||
)
|
||
main_session_first = provider._session()
|
||
main_session_second = provider._session()
|
||
self.assertIs(main_session_first, main_session_second)
|
||
|
||
holder = {}
|
||
|
||
def worker() -> None:
|
||
holder["thread_session_first"] = provider._session()
|
||
holder["thread_session_second"] = provider._session()
|
||
|
||
t = threading.Thread(target=worker)
|
||
t.start()
|
||
t.join(timeout=3)
|
||
self.assertIn("thread_session_first", holder)
|
||
self.assertIs(holder["thread_session_first"], holder["thread_session_second"])
|
||
self.assertIsNot(main_session_first, holder["thread_session_first"])
|
||
|
||
def test_self_hosted_mail_domain_normalization_removes_leading_dot(self):
|
||
provider = apm.SelfHostedMailApiProvider(
|
||
proxy="",
|
||
logger=logging.getLogger("test-self-hosted-domain"),
|
||
api_base="https://example.test",
|
||
api_key="k",
|
||
domain=".qzz.io",
|
||
)
|
||
mailbox = provider.create_mailbox()
|
||
self.assertIsNotNone(mailbox)
|
||
self.assertEqual(provider.domain, "qzz.io")
|
||
self.assertNotIn("@.", mailbox.email if mailbox else "")
|
||
|
||
def test_yyds_mail_domain_normalization_removes_leading_dot(self):
|
||
provider = apm.YYDSMailProvider(
|
||
proxy="",
|
||
logger=logging.getLogger("test-yyds-domain"),
|
||
api_base="https://example.test",
|
||
api_key="k",
|
||
domain=".qzz.io",
|
||
)
|
||
self.assertEqual(provider.domain, "qzz.io")
|
||
|
||
def test_self_hosted_provider_accepts_code_without_openai_keywords(self):
|
||
provider = apm.SelfHostedMailApiProvider(
|
||
proxy="",
|
||
logger=logging.getLogger("test-self-hosted-code"),
|
||
api_base="https://example.test",
|
||
api_key="k",
|
||
domain="qzz.io",
|
||
)
|
||
provider._fetch_latest_email = lambda _email: { # type: ignore[method-assign]
|
||
"subject": "您的登录验证码",
|
||
"text": "验证码:123456,请在页面输入",
|
||
}
|
||
codes = provider.poll_verification_codes(
|
||
apm.Mailbox(email="u@qzz.io"),
|
||
seen_ids=set(),
|
||
)
|
||
self.assertEqual(codes, ["123456"])
|
||
|
||
def test_self_hosted_provider_logs_non_200_fetch_response(self):
|
||
logger_name = "test-self-hosted-fetch-warning"
|
||
provider = apm.SelfHostedMailApiProvider(
|
||
proxy="",
|
||
logger=logging.getLogger(logger_name),
|
||
api_base="https://example.test",
|
||
api_key="k",
|
||
domain="qzz.io",
|
||
)
|
||
|
||
class FakeResponse:
|
||
status_code = 401
|
||
text = "无效的邮箱地址凭据"
|
||
|
||
def json(self):
|
||
return {}
|
||
|
||
class FakeSession:
|
||
@staticmethod
|
||
def get(*args, **kwargs):
|
||
return FakeResponse()
|
||
|
||
provider._thread_local.session = FakeSession()
|
||
with self.assertLogs(logger_name, level="WARNING") as captured:
|
||
mail_obj = provider._fetch_latest_email("u@qzz.io")
|
||
|
||
self.assertIsNone(mail_obj)
|
||
self.assertTrue(any("401" in line and "无效的邮箱地址凭据" in line for line in captured.output))
|
||
|
||
def test_yyds_provider_accepts_code_without_openai_keywords(self):
|
||
provider = apm.YYDSMailProvider(
|
||
proxy="",
|
||
logger=logging.getLogger("test-yyds-code"),
|
||
api_base="https://example.test",
|
||
api_key="k",
|
||
domain="qzz.io",
|
||
)
|
||
provider._fetch_messages = lambda _token: [{"id": "m-1"}] # type: ignore[method-assign]
|
||
provider._fetch_message_detail = lambda _token, _mid: { # type: ignore[method-assign]
|
||
"subject": "邮箱验证码",
|
||
"text": "本次验证码 654321,5 分钟内有效",
|
||
}
|
||
codes = provider.poll_verification_codes(
|
||
apm.Mailbox(email="u@qzz.io", token="tkn"),
|
||
seen_ids=set(),
|
||
)
|
||
self.assertEqual(codes, ["654321"])
|
||
|
||
def test_yyds_provider_accepts_code_from_inline_message_without_detail(self):
|
||
provider = apm.YYDSMailProvider(
|
||
proxy="",
|
||
logger=logging.getLogger("test-yyds-inline-code"),
|
||
api_base="https://example.test",
|
||
api_key="k",
|
||
domain="qzz.io",
|
||
)
|
||
provider._fetch_messages = lambda _token: [ # type: ignore[method-assign]
|
||
{"id": "m-1", "subject": "邮箱验证码", "intro": "本次验证码 112233,5 分钟内有效"}
|
||
]
|
||
provider._fetch_message_detail = lambda _token, _mid: None # type: ignore[method-assign]
|
||
codes = provider.poll_verification_codes(
|
||
apm.Mailbox(email="u@qzz.io", token="tkn"),
|
||
seen_ids=set(),
|
||
)
|
||
self.assertEqual(codes, ["112233"])
|
||
|
||
def test_yyds_provider_normalizes_prefixed_message_id_for_detail_fetch(self):
|
||
provider = apm.YYDSMailProvider(
|
||
proxy="",
|
||
logger=logging.getLogger("test-yyds-message-id"),
|
||
api_base="https://example.test",
|
||
api_key="k",
|
||
domain="qzz.io",
|
||
)
|
||
provider._fetch_messages = lambda _token: [{"id": "/messages/m-1"}] # type: ignore[method-assign]
|
||
detail_call = {}
|
||
|
||
def fake_fetch_detail(_token, message_id):
|
||
detail_call["message_id"] = message_id
|
||
return {
|
||
"subject": "邮箱验证码",
|
||
"text": "本次验证码 445566,5 分钟内有效",
|
||
}
|
||
|
||
provider._fetch_message_detail = fake_fetch_detail # type: ignore[method-assign]
|
||
codes = provider.poll_verification_codes(
|
||
apm.Mailbox(email="u@qzz.io", token="tkn"),
|
||
seen_ids=set(),
|
||
)
|
||
self.assertEqual(codes, ["445566"])
|
||
self.assertEqual(detail_call.get("message_id"), "m-1")
|
||
|
||
def test_yyds_provider_fetch_messages_reads_nested_messages_array(self):
|
||
provider = apm.YYDSMailProvider(
|
||
proxy="",
|
||
logger=logging.getLogger("test-yyds-nested-messages"),
|
||
api_base="https://example.test",
|
||
api_key="k",
|
||
domain="qzz.io",
|
||
)
|
||
|
||
class FakeResponse:
|
||
status_code = 200
|
||
content = b"1"
|
||
|
||
@staticmethod
|
||
def json():
|
||
return {
|
||
"success": True,
|
||
"data": {
|
||
"messages": [
|
||
{"id": "m-1", "subject": "邮箱验证码", "createdAt": "2026-03-28T16:00:00Z"}
|
||
]
|
||
},
|
||
}
|
||
|
||
class FakeSession:
|
||
@staticmethod
|
||
def get(*args, **kwargs):
|
||
return FakeResponse()
|
||
|
||
provider._thread_local.session = FakeSession()
|
||
messages = provider._fetch_messages("tkn")
|
||
self.assertEqual(messages, [{"id": "m-1", "subject": "邮箱验证码", "createdAt": "2026-03-28T16:00:00Z"}])
|
||
|
||
def test_self_hosted_provider_prefers_domains_over_domain(self):
|
||
provider = apm.SelfHostedMailApiProvider(
|
||
proxy="",
|
||
logger=logging.getLogger("test-self-hosted-domains-priority"),
|
||
api_base="https://example.test",
|
||
api_key="k",
|
||
domain="fallback.test",
|
||
domains=["a.test", "b.test"],
|
||
failure_threshold=2,
|
||
failure_cooldown_seconds=30.0,
|
||
)
|
||
|
||
self.assertEqual(provider.domains, ["a.test", "b.test"])
|
||
mailbox = provider.create_mailbox()
|
||
self.assertIsNotNone(mailbox)
|
||
self.assertTrue((mailbox.email if mailbox else "").endswith("@a.test"))
|
||
self.assertEqual(mailbox.domain if mailbox else "", "a.test")
|
||
|
||
def test_self_hosted_provider_rotates_domains_in_order(self):
|
||
provider = apm.SelfHostedMailApiProvider(
|
||
proxy="",
|
||
logger=logging.getLogger("test-self-hosted-rotate"),
|
||
api_base="https://example.test",
|
||
api_key="k",
|
||
domain="fallback.test",
|
||
domains=["a.test", "b.test", "c.test"],
|
||
failure_threshold=2,
|
||
failure_cooldown_seconds=30.0,
|
||
)
|
||
|
||
first = provider.create_mailbox()
|
||
second = provider.create_mailbox()
|
||
third = provider.create_mailbox()
|
||
|
||
self.assertEqual([first.domain, second.domain, third.domain], ["a.test", "b.test", "c.test"])
|
||
|
||
def test_self_hosted_provider_skips_domain_in_cooldown(self):
|
||
provider = apm.SelfHostedMailApiProvider(
|
||
proxy="",
|
||
logger=logging.getLogger("test-self-hosted-cooldown"),
|
||
api_base="https://example.test",
|
||
api_key="k",
|
||
domain="fallback.test",
|
||
domains=["a.test", "b.test"],
|
||
failure_threshold=2,
|
||
failure_cooldown_seconds=60.0,
|
||
)
|
||
|
||
provider.note_domain_failure("a.test", stage="create_mailbox")
|
||
provider.note_domain_failure("a.test", stage="create_mailbox")
|
||
|
||
mailbox = provider.create_mailbox()
|
||
self.assertIsNotNone(mailbox)
|
||
self.assertEqual(mailbox.domain if mailbox else "", "b.test")
|
||
|
||
def test_self_hosted_provider_reuses_domain_after_cooldown_expires(self):
|
||
provider = apm.SelfHostedMailApiProvider(
|
||
proxy="",
|
||
logger=logging.getLogger("test-self-hosted-cooldown-expire"),
|
||
api_base="https://example.test",
|
||
api_key="k",
|
||
domain="fallback.test",
|
||
domains=["a.test", "b.test"],
|
||
failure_threshold=1,
|
||
failure_cooldown_seconds=5.0,
|
||
)
|
||
|
||
provider.note_domain_failure("a.test", stage="create_mailbox")
|
||
provider.domain_cooldown_until["a.test"] = time.time() - 1
|
||
|
||
mailbox = provider.create_mailbox()
|
||
self.assertIsNotNone(mailbox)
|
||
self.assertEqual(mailbox.domain if mailbox else "", "a.test")
|
||
|
||
def test_cfmail_provider_create_mailbox_uses_next_available_domain(self):
|
||
provider = apm.CfmailProvider(
|
||
proxy="",
|
||
logger=logging.getLogger("test-cfmail-provider"),
|
||
api_base="https://mail.example.com",
|
||
api_key="pw",
|
||
domain="",
|
||
domains=["a.test", "b.test"],
|
||
failure_threshold=2,
|
||
failure_cooldown_seconds=60.0,
|
||
)
|
||
|
||
provider._create_address_for_domain = lambda domain: apm.Mailbox( # type: ignore[method-assign]
|
||
email=f"oc123@{domain}",
|
||
token="jwt",
|
||
domain=domain,
|
||
failure_target=domain,
|
||
)
|
||
|
||
first = provider.create_mailbox()
|
||
second = provider.create_mailbox()
|
||
|
||
self.assertIsNotNone(first)
|
||
self.assertIsNotNone(second)
|
||
self.assertEqual((first.domain, second.domain), ("a.test", "b.test"))
|
||
|
||
def test_cfmail_provider_extracts_code_from_raw_and_metadata(self):
|
||
provider = apm.CfmailProvider(
|
||
proxy="",
|
||
logger=logging.getLogger("test-cfmail-code"),
|
||
api_base="https://mail.example.com",
|
||
api_key="pw",
|
||
domain="",
|
||
domains=["a.test"],
|
||
failure_threshold=2,
|
||
failure_cooldown_seconds=60.0,
|
||
)
|
||
provider._fetch_cfmail_messages = lambda _mailbox: [ # type: ignore[method-assign]
|
||
{
|
||
"id": "m-1",
|
||
"address": "oc123@a.test",
|
||
"raw": "Subject: Your ChatGPT code is 123456",
|
||
"metadata": {"provider": "openai"},
|
||
}
|
||
]
|
||
|
||
codes = provider.poll_verification_codes(
|
||
apm.Mailbox(
|
||
email="oc123@a.test",
|
||
token="jwt",
|
||
domain="a.test",
|
||
failure_target="a.test",
|
||
),
|
||
seen_ids=set(),
|
||
)
|
||
self.assertEqual(codes, ["123456"])
|
||
|
||
def test_build_mail_provider_supports_cfmail(self):
|
||
provider = apm.build_mail_provider(
|
||
{
|
||
"mail": {"provider": "cfmail"},
|
||
"cfmail": {
|
||
"api_base": "https://mail.example.com",
|
||
"api_key": "pw",
|
||
"domains": ["a.test", "b.test"],
|
||
},
|
||
},
|
||
proxy="",
|
||
logger=logging.getLogger("test-build-cfmail"),
|
||
)
|
||
self.assertIsInstance(provider, apm.CfmailProvider)
|
||
self.assertEqual(provider.domains, ["a.test", "b.test"])
|
||
|
||
def test_api_server_merge_cfmail_api_key_preserves_masked_entries(self):
|
||
import api_server as aps
|
||
|
||
current = {
|
||
"cfmail": {
|
||
"api_base": "https://mail.example.com",
|
||
"api_key": "secret-1",
|
||
"domains": ["a.test"],
|
||
}
|
||
}
|
||
incoming = {
|
||
"cfmail": {
|
||
"api_base": "https://mail.example.com",
|
||
"api_key": aps.MASKED_VALUE,
|
||
"domains": ["a.test"],
|
||
}
|
||
}
|
||
|
||
merged = aps.merge_config_with_sensitive_fields(current, incoming)
|
||
self.assertEqual(merged["cfmail"]["api_key"], "secret-1")
|
||
|
||
class ProtocolRegistrarTests(unittest.TestCase):
|
||
def test_protocol_registrar_defaults_to_chatgpt_web_entry_mode(self):
|
||
logger = logging.getLogger("test-registration-default-entry-mode")
|
||
registrar = apm.ProtocolRegistrar(proxy="", logger=logger, conf={})
|
||
|
||
self.assertEqual(registrar.entry_mode, "chatgpt_web")
|
||
self.assertEqual(registrar._entry_mode_candidates(), ["chatgpt_web", "direct_auth"])
|
||
|
||
def test_capture_registration_tokens_uses_consent_url_redirect_code(self):
|
||
logger = logging.getLogger("test-registration-consent-code")
|
||
registrar = apm.ProtocolRegistrar(proxy="", logger=logger, conf={"flow": {"step_retry_attempts": 1}})
|
||
access_token = build_test_jwt(
|
||
{
|
||
"email": "jwt@example.com",
|
||
"exp": 1760000000,
|
||
"https://api.openai.com/auth": {"chatgpt_account_id": "acct_123"},
|
||
}
|
||
)
|
||
|
||
class FakeSession:
|
||
def __init__(self):
|
||
self.cookies = []
|
||
self.calls = []
|
||
|
||
def get(self, url, **kwargs):
|
||
self.calls.append((url, kwargs))
|
||
if url == "https://auth.openai.com/sign-in-with-chatgpt/codex/consent":
|
||
response = DummyResponse(302)
|
||
response.headers = {
|
||
"Location": (
|
||
"http://localhost:1455/auth/callback"
|
||
"?code=oauth-consent-code"
|
||
"&scope=openid+email+profile+offline_access"
|
||
"&state=oauth-state"
|
||
)
|
||
}
|
||
response.url = url
|
||
return response
|
||
if (
|
||
url
|
||
== "https://chatgpt.com/api/auth/callback/openai"
|
||
"?code=oauth-consent-code&scope=openid+email+profile+offline_access&state=oauth-state"
|
||
):
|
||
response = DummyResponse(200)
|
||
response.url = url
|
||
return response
|
||
if url == "https://chatgpt.com/api/auth/session":
|
||
response = DummyResponse(200, payload={"accessToken": access_token, "user": {"email": "jwt@example.com"}})
|
||
response.url = url
|
||
return response
|
||
raise AssertionError(f"unexpected url: {url}")
|
||
|
||
registrar.session = FakeSession() # type: ignore[assignment]
|
||
|
||
registrar._capture_registration_tokens( # type: ignore[attr-defined]
|
||
{"continue_url": "https://auth.openai.com/sign-in-with-chatgpt/codex/consent"}
|
||
)
|
||
|
||
self.assertEqual(registrar.registration_auth_code, "oauth-consent-code")
|
||
self.assertIsNotNone(registrar.registration_tokens)
|
||
self.assertEqual(registrar.registration_tokens["access_token"], access_token)
|
||
self.assertEqual(registrar.registration_tokens["email"], "jwt@example.com")
|
||
|
||
def test_capture_registration_tokens_falls_back_to_default_consent_when_add_phone_has_no_code(self):
|
||
logger = logging.getLogger("test-registration-add-phone-fallback")
|
||
registrar = apm.ProtocolRegistrar(proxy="", logger=logger, conf={"flow": {"step_retry_attempts": 1}})
|
||
access_token = build_test_jwt(
|
||
{
|
||
"email": "jwt@example.com",
|
||
"exp": 1760000000,
|
||
"https://api.openai.com/auth": {"chatgpt_account_id": "acct_123"},
|
||
}
|
||
)
|
||
|
||
class FakeSession:
|
||
def __init__(self):
|
||
self.cookies = []
|
||
self.calls = []
|
||
self.callback_completed = False
|
||
|
||
def get(self, url, **kwargs):
|
||
self.calls.append((url, kwargs))
|
||
if url == "https://auth.openai.com/add-phone":
|
||
response = DummyResponse(200, payload={"continue_url": "https://auth.openai.com/add-phone"})
|
||
response.url = url
|
||
return response
|
||
if url == "https://auth.openai.com/sign-in-with-chatgpt/codex/consent":
|
||
response = DummyResponse(302)
|
||
response.headers = {
|
||
"Location": (
|
||
"http://localhost:1455/auth/callback"
|
||
"?code=oauth-consent-code"
|
||
"&scope=openid+email+profile+offline_access"
|
||
"&state=oauth-state"
|
||
)
|
||
}
|
||
response.url = url
|
||
return response
|
||
if (
|
||
url
|
||
== "https://chatgpt.com/api/auth/callback/openai"
|
||
"?code=oauth-consent-code&scope=openid+email+profile+offline_access&state=oauth-state"
|
||
):
|
||
self.callback_completed = True
|
||
response = DummyResponse(200)
|
||
response.url = url
|
||
return response
|
||
if url == "https://chatgpt.com/api/auth/session":
|
||
payload = {"accessToken": access_token, "user": {"email": "jwt@example.com"}} if self.callback_completed else {}
|
||
response = DummyResponse(200, payload=payload)
|
||
response.url = url
|
||
return response
|
||
raise AssertionError(f"unexpected url: {url}")
|
||
|
||
registrar.session = FakeSession() # type: ignore[assignment]
|
||
|
||
registrar._capture_registration_tokens( # type: ignore[attr-defined]
|
||
{"continue_url": "https://auth.openai.com/add-phone"}
|
||
)
|
||
|
||
self.assertEqual(registrar.registration_auth_code, "oauth-consent-code")
|
||
self.assertIsNotNone(registrar.registration_tokens)
|
||
self.assertEqual(registrar.registration_tokens["access_token"], access_token)
|
||
self.assertIn(
|
||
"https://auth.openai.com/sign-in-with-chatgpt/codex/consent",
|
||
[call[0] for call in registrar.session.calls],
|
||
)
|
||
|
||
def test_capture_registration_tokens_uses_nested_create_account_code_without_following_consent(self):
|
||
logger = logging.getLogger("test-registration-nested-create-account-code")
|
||
registrar = apm.ProtocolRegistrar(proxy="", logger=logger, conf={"flow": {"step_retry_attempts": 1}})
|
||
access_token = build_test_jwt(
|
||
{
|
||
"email": "jwt@example.com",
|
||
"exp": 1760000000,
|
||
"https://api.openai.com/auth": {"chatgpt_account_id": "acct_123"},
|
||
}
|
||
)
|
||
|
||
class FakeSession:
|
||
def __init__(self):
|
||
self.cookies = []
|
||
self.calls = []
|
||
|
||
def get(self, url, **kwargs):
|
||
self.calls.append((url, kwargs))
|
||
if (
|
||
url
|
||
== "https://chatgpt.com/api/auth/callback/openai"
|
||
"?code=oauth-create-account-code&scope=openid+email+profile+offline_access&state=oauth-state"
|
||
):
|
||
response = DummyResponse(200)
|
||
response.url = url
|
||
return response
|
||
if url == "https://chatgpt.com/api/auth/session":
|
||
response = DummyResponse(200, payload={"accessToken": access_token, "user": {"email": "jwt@example.com"}})
|
||
response.url = url
|
||
return response
|
||
raise AssertionError(f"unexpected url: {url}")
|
||
|
||
registrar.session = FakeSession() # type: ignore[assignment]
|
||
|
||
registrar._capture_registration_tokens( # type: ignore[attr-defined]
|
||
{
|
||
"continue_url": "https://auth.openai.com/add-phone",
|
||
"page": {"type": "add_phone"},
|
||
"data": {
|
||
"oauth_callback": {
|
||
"code": "oauth-create-account-code",
|
||
"scope": "openid email profile offline_access",
|
||
"state": "oauth-state",
|
||
}
|
||
},
|
||
}
|
||
)
|
||
|
||
self.assertEqual(registrar.registration_auth_code, "oauth-create-account-code")
|
||
self.assertIsNotNone(registrar.registration_tokens)
|
||
self.assertEqual(registrar.registration_tokens["access_token"], access_token)
|
||
self.assertEqual(
|
||
[call[0] for call in registrar.session.calls],
|
||
[
|
||
(
|
||
"https://chatgpt.com/api/auth/callback/openai"
|
||
"?code=oauth-create-account-code&scope=openid+email+profile+offline_access&state=oauth-state"
|
||
),
|
||
"https://chatgpt.com/api/auth/session",
|
||
],
|
||
)
|
||
|
||
def test_capture_registration_tokens_uses_session_cookie_callback_without_following_consent(self):
|
||
logger = logging.getLogger("test-registration-cookie-callback-code")
|
||
registrar = apm.ProtocolRegistrar(proxy="", logger=logger, conf={"flow": {"step_retry_attempts": 1}})
|
||
access_token = build_test_jwt(
|
||
{
|
||
"email": "jwt@example.com",
|
||
"exp": 1760000000,
|
||
"https://api.openai.com/auth": {"chatgpt_account_id": "acct_123"},
|
||
}
|
||
)
|
||
cookie_payload = base64.urlsafe_b64encode(
|
||
json.dumps(
|
||
{
|
||
"continue_url": (
|
||
"http://localhost:1455/auth/callback"
|
||
"?code=oauth-cookie-code"
|
||
"&scope=openid+email+profile+offline_access"
|
||
"&state=oauth-state"
|
||
)
|
||
}
|
||
).encode("utf-8")
|
||
).rstrip(b"=").decode("ascii")
|
||
|
||
class DummyCookie:
|
||
def __init__(self, name, value):
|
||
self.name = name
|
||
self.value = value
|
||
self.domain = ".auth.openai.com"
|
||
self.path = "/"
|
||
|
||
class FakeSession:
|
||
def __init__(self):
|
||
self.cookies = [DummyCookie("oai-client-auth-session-info", f"{cookie_payload}.sig")]
|
||
self.calls = []
|
||
|
||
def get(self, url, **kwargs):
|
||
self.calls.append((url, kwargs))
|
||
if (
|
||
url
|
||
== "https://chatgpt.com/api/auth/callback/openai"
|
||
"?code=oauth-cookie-code&scope=openid+email+profile+offline_access&state=oauth-state"
|
||
):
|
||
response = DummyResponse(200)
|
||
response.url = url
|
||
return response
|
||
if url == "https://chatgpt.com/api/auth/session":
|
||
response = DummyResponse(200, payload={"accessToken": access_token, "user": {"email": "jwt@example.com"}})
|
||
response.url = url
|
||
return response
|
||
raise AssertionError(f"unexpected url: {url}")
|
||
|
||
registrar.session = FakeSession() # type: ignore[assignment]
|
||
|
||
registrar._capture_registration_tokens( # type: ignore[attr-defined]
|
||
{
|
||
"continue_url": "https://auth.openai.com/add-phone",
|
||
"page": {"type": "add_phone"},
|
||
}
|
||
)
|
||
|
||
self.assertEqual(registrar.registration_auth_code, "oauth-cookie-code")
|
||
self.assertIsNotNone(registrar.registration_tokens)
|
||
self.assertEqual(registrar.registration_tokens["access_token"], access_token)
|
||
self.assertEqual(
|
||
[call[0] for call in registrar.session.calls],
|
||
[
|
||
(
|
||
"https://chatgpt.com/api/auth/callback/openai"
|
||
"?code=oauth-cookie-code&scope=openid+email+profile+offline_access&state=oauth-state"
|
||
),
|
||
"https://chatgpt.com/api/auth/session",
|
||
],
|
||
)
|
||
|
||
def test_step4_validate_otp_sentinel_fallback(self):
|
||
logger = logging.getLogger("test-step4")
|
||
conf = {
|
||
"flow": {
|
||
"step_retry_attempts": 1,
|
||
"register_otp_validate_order": "normal,sentinel",
|
||
}
|
||
}
|
||
registrar = apm.ProtocolRegistrar(proxy="", logger=logger, conf=conf)
|
||
registrar.sentinel_gen.generate_token = lambda *_args, **_kwargs: "token-sentinel"
|
||
|
||
captured_headers = []
|
||
|
||
def fake_post(_url, **kwargs):
|
||
captured_headers.append(kwargs.get("headers") or {})
|
||
if len(captured_headers) == 1:
|
||
return DummyResponse(400)
|
||
return DummyResponse(200)
|
||
|
||
registrar.session.post = fake_post
|
||
|
||
ok = registrar.step4_validate_otp("123456")
|
||
|
||
self.assertTrue(ok)
|
||
self.assertEqual(len(captured_headers), 2)
|
||
self.assertNotIn("openai-sentinel-token", captured_headers[0])
|
||
self.assertEqual(captured_headers[1].get("openai-sentinel-token"), "token-sentinel")
|
||
|
||
def test_register_passes_mail_poll_interval_to_provider(self):
|
||
logger = logging.getLogger("test-register-mail-poll-interval")
|
||
registrar = apm.ProtocolRegistrar(proxy="", logger=logger, conf={"flow": {"step_retry_attempts": 1}})
|
||
|
||
registrar.step0_init_oauth_session = lambda *_args, **_kwargs: True
|
||
registrar.step2_register_user = lambda *_args, **_kwargs: True
|
||
registrar.step3_send_otp = lambda *_args, **_kwargs: True
|
||
registrar.step4_validate_otp = lambda *_args, **_kwargs: True
|
||
registrar.step5_create_account = lambda *_args, **_kwargs: True
|
||
|
||
class FakeMailProvider:
|
||
provider_name = "fake"
|
||
|
||
def __init__(self):
|
||
self.called_kwargs = {}
|
||
|
||
def wait_for_verification_code(self, _mailbox, **kwargs):
|
||
self.called_kwargs = kwargs
|
||
return "123456"
|
||
|
||
provider = FakeMailProvider()
|
||
|
||
with patch("auto_pool_maintainer.time.sleep", lambda *_args, **_kwargs: None):
|
||
ok = registrar.register(
|
||
email="test@example.com",
|
||
password="pw",
|
||
client_id="cid",
|
||
redirect_uri="http://localhost/cb",
|
||
mailbox=apm.Mailbox(email="test@example.com"),
|
||
mail_provider=provider, # type: ignore[arg-type]
|
||
otp_timeout_seconds=88,
|
||
otp_poll_interval_seconds=1.25,
|
||
)
|
||
|
||
self.assertTrue(ok)
|
||
self.assertEqual(provider.called_kwargs.get("timeout"), 88)
|
||
self.assertEqual(provider.called_kwargs.get("poll_interval_seconds"), 1.25)
|
||
|
||
|
||
class RegisterOneFlowTests(unittest.TestCase):
|
||
class _FakeMailProvider:
|
||
provider_name = "fake"
|
||
|
||
@staticmethod
|
||
def create_mailbox():
|
||
return apm.Mailbox(email="fake@example.com")
|
||
|
||
class _FakeRuntime:
|
||
def __init__(self, oauth_token=None):
|
||
self.stop_event = threading.Event()
|
||
self.target_tokens = 1
|
||
self._token_count = 0
|
||
self.mail_provider = RegisterOneFlowTests._FakeMailProvider()
|
||
self.mail_provider_name = "fake"
|
||
self.logger = logging.getLogger("test-register-one")
|
||
self.proxy = ""
|
||
self.conf = {}
|
||
self.oauth_client_id = "cid"
|
||
self.oauth_redirect_uri = "http://localhost/cb"
|
||
self.mail_otp_timeout_seconds = 60
|
||
self.mail_poll_interval_seconds = 1.0
|
||
self.oauth_outer_retry_attempts = 3
|
||
self.last_oauth_failure_detail = ""
|
||
self.oauth_token = oauth_token
|
||
self.oauth_called = False
|
||
self.saved_tokens = None
|
||
self.saved_account = None
|
||
self.success_key = None
|
||
|
||
def get_token_success_count(self):
|
||
return self._token_count
|
||
|
||
def wait_for_provider_availability(self, worker_id=0):
|
||
return None
|
||
|
||
def oauth_login_with_retry(self, mailbox, password):
|
||
self.oauth_called = True
|
||
return self.oauth_token
|
||
|
||
def claim_token_slot(self):
|
||
self._token_count += 1
|
||
return True, self._token_count
|
||
|
||
def release_token_slot(self):
|
||
self._token_count = max(0, self._token_count - 1)
|
||
|
||
def save_tokens(self, email, tokens):
|
||
self.saved_tokens = tokens
|
||
return True
|
||
|
||
def save_account(self, email, password):
|
||
self.saved_account = (email, password)
|
||
|
||
def note_attempt_success(self, success_key="register_oauth_success"):
|
||
self.success_key = success_key
|
||
|
||
def note_attempt_failure(self, stage, email="", detail=""):
|
||
raise AssertionError(f"unexpected failure: stage={stage} email={email} detail={detail}")
|
||
|
||
class _FakeRegistrar:
|
||
def __init__(self, proxy, logger, conf):
|
||
self.last_failure_detail = ""
|
||
self.last_failure_stage = ""
|
||
|
||
def register(self, **kwargs):
|
||
return True
|
||
|
||
def exchange_codex_tokens(self, client_id, redirect_uri):
|
||
raise AssertionError("register_one 不应再调用 exchange_codex_tokens")
|
||
|
||
def test_register_one_calls_oauth_path(self):
|
||
fake_runtime = self._FakeRuntime(oauth_token={"access_token": "oauth-token"})
|
||
|
||
class Registrar(self._FakeRegistrar):
|
||
pass
|
||
|
||
with patch("auto_pool_maintainer.ProtocolRegistrar", Registrar), patch(
|
||
"auto_pool_maintainer.generate_random_password", lambda: "Pw123456!"
|
||
):
|
||
_, success, _, _ = apm.register_one(fake_runtime, worker_id=1)
|
||
|
||
self.assertTrue(success)
|
||
self.assertTrue(fake_runtime.oauth_called)
|
||
self.assertEqual(fake_runtime.saved_tokens, {"access_token": "oauth-token"})
|
||
self.assertEqual(fake_runtime.success_key, "register_oauth_success")
|
||
|
||
def test_register_one_prefers_registration_session_tokens(self):
|
||
class RuntimeWithoutOauth(self._FakeRuntime):
|
||
def oauth_login_with_retry(self, mailbox, password):
|
||
raise AssertionError("已有注册阶段 token 时不应再跑 OAuth 登录")
|
||
|
||
runtime = RuntimeWithoutOauth(oauth_token=None)
|
||
|
||
class Registrar(self._FakeRegistrar):
|
||
def __init__(self, proxy, logger, conf):
|
||
super().__init__(proxy, logger, conf)
|
||
self.registration_tokens = {"access_token": "session-token", "email": "fake@example.com"}
|
||
|
||
with patch("auto_pool_maintainer.ProtocolRegistrar", Registrar), patch(
|
||
"auto_pool_maintainer.generate_random_password", lambda: "Pw123456!"
|
||
):
|
||
_, success, _, _ = apm.register_one(runtime, worker_id=1)
|
||
|
||
self.assertTrue(success)
|
||
self.assertEqual(runtime.saved_tokens, {"access_token": "session-token", "email": "fake@example.com"})
|
||
self.assertEqual(runtime.success_key, "register_oauth_success")
|
||
|
||
def test_register_one_returns_fail_when_oauth_failed(self):
|
||
class RuntimeWithFailure(self._FakeRuntime):
|
||
failure_events = []
|
||
|
||
def note_attempt_failure(self, stage, email="", detail=""):
|
||
self.failure_events.append((stage, email, detail))
|
||
|
||
runtime = RuntimeWithFailure(oauth_token=None)
|
||
|
||
class Registrar(self._FakeRegistrar):
|
||
pass
|
||
|
||
with patch("auto_pool_maintainer.ProtocolRegistrar", Registrar), patch(
|
||
"auto_pool_maintainer.generate_random_password", lambda: "Pw123456!"
|
||
):
|
||
_, success, _, _ = apm.register_one(runtime, worker_id=1)
|
||
|
||
self.assertFalse(success)
|
||
self.assertTrue(runtime.oauth_called)
|
||
self.assertTrue(runtime.failure_events)
|
||
self.assertEqual(runtime.failure_events[-1][0], "oauth")
|
||
|
||
def test_register_one_create_mailbox_failure_marks_selected_domain(self):
|
||
class FakeMailProvider:
|
||
provider_name = "fake"
|
||
|
||
def __init__(self):
|
||
self.last_selected_domain = "a.test"
|
||
self.failure_calls = []
|
||
|
||
def wait_for_availability(self, worker_id=0):
|
||
return None
|
||
|
||
def create_mailbox(self):
|
||
return None
|
||
|
||
def note_domain_failure(self, domain, *, stage, detail=""):
|
||
self.failure_calls.append((domain, stage, detail))
|
||
|
||
def note_domain_success(self, domain):
|
||
return None
|
||
|
||
class FakeRuntime(self._FakeRuntime):
|
||
def __init__(self):
|
||
super().__init__()
|
||
self.mail_provider = FakeMailProvider()
|
||
self.mail_provider_name = "fake"
|
||
self.failure_events = []
|
||
|
||
def note_attempt_failure(self, stage, email="", detail=""):
|
||
self.failure_events.append((stage, email, detail))
|
||
|
||
runtime = FakeRuntime()
|
||
email, success, _, _ = apm.register_one(runtime)
|
||
|
||
self.assertIsNone(email)
|
||
self.assertFalse(success)
|
||
self.assertEqual(runtime.mail_provider.failure_calls, [("a.test", "create_mailbox", "provider=fake")])
|
||
|
||
def test_register_one_register_mail_timeout_marks_mailbox_domain(self):
|
||
class FakeMailProvider(self._FakeMailProvider):
|
||
provider_name = "fake"
|
||
|
||
def __init__(self):
|
||
self.failure_calls = []
|
||
|
||
def wait_for_availability(self, worker_id=0):
|
||
return None
|
||
|
||
@staticmethod
|
||
def create_mailbox():
|
||
return apm.Mailbox(email="fake@example.com", domain="a.test")
|
||
|
||
def note_domain_failure(self, domain, *, stage, detail=""):
|
||
self.failure_calls.append((domain, stage, detail))
|
||
|
||
def note_domain_success(self, domain):
|
||
return None
|
||
|
||
class FakeRuntime(self._FakeRuntime):
|
||
def __init__(self):
|
||
super().__init__()
|
||
self.mail_provider = FakeMailProvider()
|
||
self.mail_provider_name = "fake"
|
||
self.failure_events = []
|
||
|
||
def note_attempt_failure(self, stage, email="", detail=""):
|
||
self.failure_events.append((stage, email, detail))
|
||
|
||
class FakeRegistrar(self._FakeRegistrar):
|
||
def register(self, **kwargs):
|
||
self.last_failure_stage = "register_mail_otp_timeout"
|
||
self.last_failure_detail = "provider=fake"
|
||
return False
|
||
|
||
runtime = FakeRuntime()
|
||
with patch("auto_pool_maintainer.ProtocolRegistrar", FakeRegistrar), patch(
|
||
"auto_pool_maintainer.generate_random_password", lambda: "Pw123456!"
|
||
):
|
||
email, success, _, _ = apm.register_one(runtime, worker_id=1)
|
||
|
||
self.assertEqual(email, "fake@example.com")
|
||
self.assertFalse(success)
|
||
self.assertEqual(runtime.mail_provider.failure_calls, [("a.test", "register", "provider=fake")])
|
||
|
||
|
||
if __name__ == "__main__":
|
||
unittest.main()
|