from collections import Counter
from collections.abc import Mapping
from math import sqrt
from typing import Union


class MovieRecommender:
    """Movie recommender over small, hard-coded in-memory datasets.

    Attributes:
        movie_info: maps a movie id to a ``(title, genres)`` tuple, where
            ``genres`` is a tuple of genre names.
        all_user_ratings: maps a user id to a ``{movie_id: rating}`` dict.
    """

    def __init__(self):
        # movie_id -> (title, tuple of genre names)
        self.movie_info = {
            1210: (
                "Star Wars: Episode VI - Return of the Jedi",
                ("Action", "Adventure", "Sci-Fi"),
            ),
            2028: ("Saving Private Ryan", ("Action", "Drama", "War")),
            1307: ("When Harry Met Sally...", ("Comedy", "Romance")),
            5418: ("Bourne Identity, The", ("Action", "Mystery", "Thriller")),
            56367: ("Juno", ("Comedy", "Drama", "Romance")),
            3751: ("Chicken Run", ("Animation", "Children", "Comedy")),
        }

        # user_id -> {movie_id: rating}
        self.all_user_ratings = {
            514: {2028: 5.0, 1210: 2.0},
            279: {1210: 4.0, 1307: 2.5, 56367: 0.5},
        }

    def print_all_genres(self, movie_id: int) -> None:
        """Print each genre of the given movie on its own line.

        Raises:
            KeyError: if ``movie_id`` is not in ``movie_info``.
        """
        _title, genres = self.movie_info[movie_id]
        for genre in genres:
            print(genre)

    def count_movies_by_genre(self, user_id: int) -> dict[str, int]:
        """Return a dictionary mapping genres to the number of movies that
        the input user has rated from that genre.

        A movie tagged with several genres counts once towards each of them.
        Keys appear in the order their genre is first encountered.

        Raises:
            KeyError: if ``user_id`` is not in ``all_user_ratings``, or a
                rated movie is missing from ``movie_info``.
        """
        user_ratings = self.all_user_ratings[user_id]
        counter = Counter(
            genre
            for movie_id in user_ratings
            for genre in self.movie_info[movie_id][1]
        )
        # Return a plain dict so the result matches the annotated type exactly.
        return dict(counter)

    @staticmethod
    def dot(
        a: Union[list[float], Mapping[str, float]],
        b: Union[list[float], Mapping[str, float]],
    ) -> float:
        """Return the dot product of two vectors.

        Each vector is either a dense sequence of numbers or a sparse mapping
        from a dimension name to its value.  For two mappings, only keys
        present in both contribute; a missing key is treated as 0.

        Raises:
            ValueError: if two dense vectors have different lengths.
            TypeError: if a dense vector is combined with a sparse one.
        """
        a_sparse = isinstance(a, Mapping)
        b_sparse = isinstance(b, Mapping)
        if a_sparse and b_sparse:
            # Walk the smaller mapping and look its keys up in the larger one.
            small, large = (a, b) if len(a) <= len(b) else (b, a)
            products = (
                value * large[key] for key, value in small.items() if key in large
            )
            return sum(products, 0.0)
        if a_sparse or b_sparse:
            raise TypeError(
                "cannot take the dot product of a sparse and a dense vector"
            )
        if len(a) != len(b):
            raise ValueError(f"vector lengths differ: {len(a)} != {len(b)}")
        # Start from 0.0 so empty vectors still yield a float.
        return sum((x * y for x, y in zip(a, b)), 0.0)

    @staticmethod
    def mag(a: Union[list[float], Mapping[str, float]]) -> float:
        """Return the Euclidean magnitude (L2 norm) of a vector.

        ``a`` may be a dense sequence of numbers, or a sparse mapping whose
        values are the vector's components.  An empty vector has magnitude 0.0.
        """
        components = a.values() if isinstance(a, Mapping) else a
        return sqrt(sum(x * x for x in components))


if __name__ == "__main__":
    # Script entry point: call print_all_genres for movie 3751 ("Chicken Run")
    # on a freshly constructed recommender.
    MovieRecommender().print_all_genres(3751)
