LeetCode 272. Closest Binary Search Tree Value II

Description

https://leetcode.com/problems/closest-binary-search-tree-value-ii/

Given the root of a binary search tree, a target value, and an integer k, return the k values in the BST that are closest to the target. You may return the answer in any order.

You are guaranteed to have only one unique set of k values in the BST that are closest to the target.

Example 1:

Input: root = [4,2,5,1,3], target = 3.714286, k = 2
Output: [4,3]

Example 2:

Input: root = [1], target = 0.000000, k = 1
Output: [1]

Constraints:

  • The number of nodes in the tree is n.
  • 1 <= k <= n <= 10^4.
  • 0 <= Node.val <= 10^9
  • -10^9 <= target <= 10^9
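A brute-force version makes a useful reference point. It collects every value and keeps the k values closest to the target. The sketch below is an illustration and is not part of the original solution. `closest_k_values_brute_force` is a hypothetical name. Because it ignores the BST ordering, it runs in O(n log k) time, so it is mainly useful for checking a faster solution on small inputs.

```python
import heapq
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def closest_k_values_brute_force(root: Optional[TreeNode], target: float, k: int) -> List[int]:
    # Hypothetical helper: gather every value (visit order does not matter here).
    values = []
    stack = [root] if root else []
    while stack:
        node = stack.pop()
        values.append(node.val)
        if node.left:
            stack.append(node.left)
        if node.right:
            stack.append(node.right)
    # Keep the k values with the smallest distance to target, closest first.
    return heapq.nsmallest(k, values, key=lambda v: abs(v - target))


# Example 1: the BST [4,2,5,1,3] built by hand.
root = TreeNode(4, TreeNode(2, TreeNode(1), TreeNode(3)), TreeNode(5))
print(closest_k_values_brute_force(root, 3.714286, 2))  # [4, 3]
```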

Follow up: Assume that the BST is balanced. Could you solve it in less than O(n) runtime (where n = total nodes)?

Python Solution

First, an in-order traversal collects the tree's values in ascending order. Next, a binary search finds the index at which the target would be inserted. From that index, two pointers expand to the left and to the right. At each step the pointer whose value is closer to the target is taken, and this repeats until k values have been collected.
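The search-then-expand step can be tried on its own against a sorted list. This minimal sketch assumes the values are already in ascending order. It uses the standard library's `bisect.bisect_left` instead of a hand-written binary search, and `expand_from_insertion_point` is a hypothetical helper name. `bisect_left` returns the first index whose value is >= target, so subtracting one gives the last index whose value is < target. This is also what `find_lower_index` in the solution below computes.

```python
import bisect
from typing import List


def expand_from_insertion_point(values: List[int], target: float, k: int) -> List[int]:
    # Hypothetical helper; assumes `values` is sorted ascending (an in-order traversal).
    right = bisect.bisect_left(values, target)  # first index with value >= target
    left = right - 1                            # last index with value < target, or -1
    result = []
    for _ in range(k):
        # Take from the left only if it exists and is strictly closer; ties go right.
        if left >= 0 and (right >= len(values) or target - values[left] < values[right] - target):
            result.append(values[left])
            left -= 1
        else:
            result.append(values[right])
            right += 1
    return result


# In-order values of Example 1's tree.
print(expand_from_insertion_point([1, 2, 3, 4, 5], 3.714286, 2))  # [4, 3]
```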

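from typing import List, Optional


# Definition for a binary tree node (LeetCode provides it; repeated here so the code runs on its own).
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class Solution:
    def closestKValues(self, root: Optional[TreeNode], target: float, k: int) -> List[int]:
        # 1. An in-order traversal of a BST yields its values in ascending order.
        inorder_values = []
        self.dfs(root, inorder_values)

        # 2. left = index of the largest value < target (or -1); right is the next index.
        left = self.find_lower_index(inorder_values, target)
        right = left + 1

        # 3. Expand outward k times, always taking the candidate closer to target.
        results = []
        for _ in range(k):
            if self.is_left_closer(inorder_values, left, right, target):
                results.append(inorder_values[left])
                left -= 1
            else:
                results.append(inorder_values[right])
                right += 1

        return results

    def is_left_closer(self, nums, left, right, target):
        # An exhausted side is never chosen; k <= n guarantees both sides are never exhausted together.
        if left < 0:
            return False
        if right >= len(nums):
            return True
        return target - nums[left] < nums[right] - target

    def find_lower_index(self, nums, target):
        # Binary search for the last index whose value is < target; -1 if there is none.
        start = 0
        end = len(nums) - 1

        while start + 1 < end:
            mid = start + (end - start) // 2
            if nums[mid] < target:
                start = mid
            else:
                end = mid

        if nums[end] < target:
            return end
        if nums[start] < target:
            return start
        return -1

    def dfs(self, root, results):
        # Iterative in-order traversal. An explicit stack avoids exceeding Python's
        # default recursion limit on a skewed tree with up to 10^4 nodes.
        stack = []
        node = root
        while stack or node:
            while node:
                stack.append(node)
                node = node.left
            node = stack.pop()
            results.append(node.val)
            node = node.right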
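
  • Time Complexity: O(n). The in-order traversal visits every node once, the binary search adds O(log n), and the outward expansion adds O(k). Because of the full traversal, this solution does not meet the follow-up's sub-linear goal.
  • Space Complexity: O(n) for the list of in-order values. The traversal stack adds O(h), where h is the tree height.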
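For the follow-up, a common approach (not the one used above) avoids the full traversal by keeping two stacks. One stack yields in-order predecessors of the target in descending order, and the other yields successors in ascending order. The two streams are then merged k times, in the same way as the two pointers above. This is a hypothetical sketch: `closest_k_values_two_stacks` is an illustrative name, and the code assumes LeetCode's `TreeNode` shape. Seeding the stacks costs O(h). Each stack only ever holds nodes from one root-to-leaf path, so the total work is O(h + k). On a balanced tree that is O(log n + k).

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def closest_k_values_two_stacks(root: Optional[TreeNode], target: float, k: int) -> List[int]:
    # Hypothetical sketch. pred: top is the largest unvisited value < target.
    # succ: top is the smallest unvisited value >= target.
    pred: List[TreeNode] = []
    succ: List[TreeNode] = []
    node = root
    while node:  # a single root-to-leaf walk seeds both stacks
        if node.val < target:
            pred.append(node)
            node = node.right
        else:
            succ.append(node)
            node = node.left

    def next_smaller() -> int:
        # Pop the current predecessor, then push the right spine of its left subtree.
        top = pred.pop()
        child = top.left
        while child:
            pred.append(child)
            child = child.right
        return top.val

    def next_larger() -> int:
        # Pop the current successor, then push the left spine of its right subtree.
        top = succ.pop()
        child = top.right
        while child:
            succ.append(child)
            child = child.left
        return top.val

    result = []
    for _ in range(k):
        # Same rule as is_left_closer: take the smaller side only if it is strictly closer.
        take_smaller = bool(pred) and (not succ or target - pred[-1].val < succ[-1].val - target)
        result.append(next_smaller() if take_smaller else next_larger())
    return result


# Example 1: the BST [4,2,5,1,3] built by hand.
root = TreeNode(4, TreeNode(2, TreeNode(1), TreeNode(3)), TreeNode(5))
print(closest_k_values_two_stacks(root, 3.714286, 2))  # [4, 3]
```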
