import json
import time
from collections import defaultdict
from datetime import datetime
from pprint import pprint

import requests
from loguru import logger
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.support import expected_conditions as ec
from selenium.webdriver.support.wait import WebDriverWait

USERID, USERNAME, PASSWORD = "", "", ""
DRIVER = None  # shared Selenium session, set by login()
def login():
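    """Log in to X (Twitter) with a headless Chrome session.

    Walks the multi-step login flow (username, optional handle verification,
    password), waits for the redirect to /home, and stores the authenticated
    driver in the module-level DRIVER. Returns the driver on success, or
    None on failure or missing credentials.
    """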
    global USERID, USERNAME, PASSWORD, DRIVER
    if not USERID or not USERNAME or not PASSWORD:
        return None
    
    driver = None
    try:
        options = webdriver.ChromeOptions()
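        # Enable Chrome performance logging so GraphQL response bodies can
        # later be pulled from the network events via CDP (see get_timeline).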
        options.set_capability("goog:loggingPrefs", {"performance": "ALL"})
        options.add_argument("--headless")
        driver = webdriver.Chrome(options=options)
        driver.set_page_load_timeout(30)
        driver.get("https://x.com/i/flow/login")
        
        WebDriverWait(driver, 10).until(
            ec.presence_of_element_located((By.CSS_SELECTOR, 'input[autocomplete="username"]')))
        username_field = driver.find_element(By.CSS_SELECTOR, 'input[autocomplete="username"]')
        username_field.send_keys(USERNAME)
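        # No stable test ID at this step: the "Next" button is picked by its
        # position among the page's <button> elements, which is brittle.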
        buttons = driver.find_elements(By.TAG_NAME, 'button')
        buttons[2].click()
        
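        # Handle the extra verification step X sometimes shows, which asks
        # for the account handle; fill it in only if the field is empty.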
        WebDriverWait(driver, 10).until(
            ec.presence_of_element_located((By.CSS_SELECTOR, 'input[autocomplete="on"]')))
        userid_field = driver.find_element(By.CSS_SELECTOR, 'input[autocomplete="on"]')
        if not userid_field.get_attribute("value"):
            userid_field.send_keys(USERID)
            buttons = driver.find_elements(By.TAG_NAME, 'button')
            buttons[1].click()
        
        WebDriverWait(driver, 10).until(
            ec.presence_of_element_located((By.CSS_SELECTOR, 'input[autocomplete="current-password"]')))
        password_field = driver.find_element(By.CSS_SELECTOR, 'input[autocomplete="current-password"]')
        password_field.send_keys(PASSWORD)
        login_button = driver.find_element(By.CSS_SELECTOR, 'button[data-testid="LoginForm_Login_Button"]')
        login_button.click()
        
        WebDriverWait(driver, 60).until(ec.url_contains('/home'))
        cookies = driver.get_cookies()
        cookie_string = "; ".join([f"{cookie['name']}={cookie['value']}" for cookie in cookies])
        logger.success(f"Twitter login success for username {USERNAME}\n{cookie_string}")
        DRIVER = driver
        return driver

    except Exception as e:
        logger.error(f"Twitter login failed for username {USERNAME}: {e}")
        if driver:
            driver.quit()
        return None


ERROR_COUNT = 0
def get_timeline(url):
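    """Open a list timeline URL and return its raw GraphQL JSON.

    Rather than scraping the DOM, this reads Chrome's performance log to
    locate the ListLatestTweetsTimeline request and fetches its response
    body over CDP. Returns {} on failure; after more than 5 consecutive
    failures the browser session is recycled via a fresh login().
    """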
    global ERROR_COUNT, DRIVER
    logger.info(f"check timeline {url}")
    try:
        driver = DRIVER
        driver.get(url)
        WebDriverWait(driver, 30).until(
            ec.presence_of_element_located((By.CSS_SELECTOR, 'div[aria-label="Timeline: List"]')))
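        # Scan the CDP network events for the GraphQL response that carries
        # the list timeline, then fetch its body by request id.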
        for packet in driver.get_log("performance"):
            message = json.loads(packet["message"])["message"]
            if (message["method"] == "Network.responseReceived" and
                "ListLatestTweetsTimeline" in message["params"]["response"]["url"]):
                request_id = message["params"]["requestId"]
                resp = driver.execute_cdp_cmd('Network.getResponseBody', {'requestId': request_id})
                logger.info("checked")
                ERROR_COUNT = 0
                return json.loads(resp["body"])
    except Exception as e:
        logger.error(f"check failed: {e}")
        ERROR_COUNT += 1
        if ERROR_COUNT > 5:
            # Too many consecutive failures: the session is probably dead,
            # so recycle the browser and log in again.
            if driver:
                driver.quit()
            login()
            ERROR_COUNT = 0
        return {}
    # The expected response never appeared in the performance log.
    return {}



def parse_timeline(data):
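    """Flatten raw timeline JSON into a list of tweet dicts, newest first."""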
    entries = data["data"]["list"]["tweets_timeline"]["timeline"]["instructions"][0]["entries"]
    result = []
    for entry in entries:
        try:
            result += parse_entry(entry)
        except Exception as e:
            logger.error(f"error when parsing entry: {e} {e.args}\n{entry}")
    result.sort(key=lambda x: x["timestamp"], reverse=True)
    return result

def parse_entry(entry):
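    """Parse one timeline entry; a conversation entry may yield several tweets."""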
    result = []
    entry_id = entry["entryId"]
    if "list-conversation" in entry_id and not "tweet" in entry_id:
        for item in entry["content"]["items"]:
            data = parse_content(item["item"])
            if data: result.append(data)
    elif entry["content"]["__typename"] != 'TimelineTimelineCursor':
        data = parse_content(entry["content"])
        if data: result.append(data)
    return result
        
def parse_content(content):
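    """Extract the tweet object from a timeline content node and parse it,
    along with any quoted or retweeted tweet. Returns {} on parse errors."""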
    tweet = content["itemContent"]["tweet_results"]["result"]
    while not "rest_id" in tweet: tweet = tweet["tweet"]
    try:
        data = parse_tweet(tweet)
        if "quoted_status_result" in tweet:
            data["quoted"] = parse_tweet(tweet["quoted_status_result"]["result"])
        if "retweeted_status_result" in tweet["legacy"]:
            data["retweeted"] = parse_tweet(tweet["legacy"]["retweeted_status_result"]["result"])
        return data
    except Exception as e:
        logger.error(f"error when parsing tweet: {e} {e.args}\n{tweet}")
        return {}

def parse_media(media):
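    """Return the original-size image URL and, for videos and animated GIFs,
    the URL of the highest-bitrate variant."""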
    data = {
        "url": media["media_url_https"] + "?name=orig",
        "video": ""
    }
    if media["type"] in ["video", "animated_gif"]:
        variants = [i for i in media["video_info"]["variants"] if "bitrate" in i]
        variants.sort(key=lambda x: x["bitrate"], reverse=True)
        if variants: data["video"] = variants[0]["url"]
    return data

def parse_entities(entity):
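    """Normalize a mention, hashtag, or URL entity into {text, indices}."""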
    data = {
        "text": "",
        "indices": entity["indices"]
    }
    if "name" in entity: data["text"] = "@" + entity["name"]
    if "text" in entity: data["text"] = "#" + entity["text"]
    if "display_url" in entity: data["text"] = entity["display_url"]
    return data

def parse_card(card):
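    """Extract poll choices and end time from a card's binding values, plus
    an optional photo from a unified_card payload. Returns (data, photo)."""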
    data = {}
    for v in card["legacy"]["binding_values"]:
        if "choice" in v["key"] or v["key"] in ["end_datetime_utc", "unified_card"]:
            value_name = f"{v['value']['type'].lower()}_value"
            data[v["key"]] = v['value'].get(value_name, "")
    
    photo = None
    if "unified_card" in data:
        card_data = json.loads(data["unified_card"])
        del data["unified_card"]
        try:
            # Use the first media entity that actually has an image URL.
            for v in card_data["media_entities"].values():
                if "media_url_https" in v:
                    photo = {
                        "url": v["media_url_https"] + "?name=orig",
                        "video": ""
                    }
                    break
        except Exception as e:
            logger.error(f"error parsing unified_card: {e}\n{card_data}")
    
    return data, photo

def parse_tweet(tweet):
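    """Convert a raw GraphQL tweet object into a flat dict with author info,
    text, timestamp, media, entities, and any quote/retweet/card payloads."""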
    # with open("tweet.json", "w") as f: json.dump(tweet, f)
    while not "rest_id" in tweet: tweet = tweet["tweet"]
    data = {
        "rest_id": tweet["rest_id"],
        "name": tweet["core"]["user_results"]["result"]["legacy"]["name"],
        "screen_name": tweet["core"]["user_results"]["result"]["legacy"]["screen_name"],
        "profile_image": tweet["core"]["user_results"]["result"]["legacy"]["profile_image_url_https"],
        "profile_image_shape": tweet["core"]["user_results"]["result"]["profile_image_shape"],
        "full_text": tweet["legacy"]["full_text"],
        "created_at": tweet["legacy"]["created_at"],
        "timestamp": int(datetime.strptime(tweet["legacy"]["created_at"], '%a %b %d %H:%M:%S %z %Y').timestamp()),
        "media": [],
        "entities": [],
        "quoted": {},
        "retweeted": {},
        "card": {}
    }
    data["profile_image"] = data["profile_image"].replace("_normal.", ".")
    
    for m in tweet["legacy"]["entities"].get("media", []):
        data["media"].append(parse_media(m))
    
    for e in ["user_mentions", "hashtags", "urls"]:
        for m in tweet["legacy"]["entities"].get(e, []):
            data["entities"].append(parse_entities(m))
    data["entities"].sort(key=lambda x: x["indices"][0])
    
    if "card" in tweet:
        data["card"], _photo = parse_card(tweet["card"])
        if _photo: data["media"].append(_photo)
    
    return data




LATEST_TWEET_ID_DICT = {}
LATEST_TWEET_TS_DICT = {}
def check_new_tweets(tweets, url):
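    """Return the tweets for this URL that are newer than the last check.

    Walks the (newest-first) list until it hits the previously recorded
    tweet id or timestamp, and also stops at anything older than 20 minutes
    (1200 s) as a staleness cutoff.
    """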
    global LATEST_TWEET_ID_DICT, LATEST_TWEET_TS_DICT
    
    if not tweets:
        return []
    
    new_tweets = []
    if url in LATEST_TWEET_ID_DICT:
        for tweet in tweets:
            if tweet["rest_id"] == LATEST_TWEET_ID_DICT[url]:
                break
            if tweet["timestamp"] < LATEST_TWEET_TS_DICT[url]:
                break
            if time.time() - tweet["timestamp"] > 1200:
                break
            new_tweets.append(tweet)
            
    LATEST_TWEET_ID_DICT[url] = tweets[0]["rest_id"]
    LATEST_TWEET_TS_DICT[url] = tweets[0]["timestamp"]
    return new_tweets

def filter_tweets(tweets, filter_list):
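    """Apply the configured filters: "only_image" keeps tweets with media,
    "only_origin" drops quotes and retweets."""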
    
    if "only_image" in filter_list:
        tweets = [t for t in tweets if t["media"]]
        
    if "only_origin" in filter_list:
        tweets = [t for t in tweets if (not t["quoted"]) and (not t["retweeted"])]
        
    return tweets

def check_timeline(config):
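    """Fetch, parse, diff, and filter one configured timeline; returns the
    list of new tweets (possibly empty)."""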
    data = get_timeline(config["url"])
    if data:
        tweets = parse_timeline(data)
        new_tweets = check_new_tweets(tweets, config["url"])
        return filter_tweets(new_tweets, config["filter"])
    else:
        return []




def main(config):
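    """Log in, then poll each configured timeline on its interval and POST
    any new tweets to the callback URL, slowing down during slow_hours."""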
    global USERID, USERNAME, PASSWORD
    USERID = config["userid"] # screen name (the handle after the @)
    USERNAME = config["username"] # login username or email
    PASSWORD = config["password"] # password
    if login() is None:
        logger.error("initial login failed, exiting")
        return
    
    check_list = config.get("check_list", {})
    check_interval = config.get("check_interval", 42)
    check_interval_slow = config.get("check_interval_slow", 600)
    slow_hours = config.get("slow_hours", [0, 1, 2, 3, 4, 5, 6])
    last_check_time = defaultdict(lambda: 0.0)
    
    while True:
        json_data = {}
        for group_id, group_config in check_list.items():
            group_interval = group_config.get("interval", check_interval)
            
            if time.time() - last_check_time[group_id] > group_interval: 
                new_tweets = check_timeline(group_config)
                if new_tweets: 
                    json_data[group_id] = new_tweets
                last_check_time[group_id] = time.time()
                
        if json_data:
            pprint(json_data)
            try:
                resp = requests.post(config["callback_url"], json=json_data)
                logger.info(resp.content)
            except Exception as e:
                logger.error(str(e))
                
        if datetime.now().hour in slow_hours:
            time.sleep(check_interval_slow)
        else:
            time.sleep(check_interval)

if __name__ == "__main__":
    with open("config.json", 'r') as f:
        config = json.load(f)
    main(config)
    
    # with open("lovelive.json", 'r', encoding="utf8") as f: pprint(parse_timeline(json.load(f)))