from manim import *
import math

def norm(vec):
    """Return the Euclidean length of ``vec``.

    ``vec`` is any iterable of numbers; an empty iterable gives ``0.0``.
    """
    squared_components = (component ** 2 for component in vec)
    return math.sqrt(sum(squared_components))

def normalize(vec):
    """Return ``vec`` scaled to unit Euclidean length, as a new list.

    Each component is divided by ``norm(vec)``. A non-empty vector of
    length zero therefore raises ``ZeroDivisionError``; an empty vector
    yields an empty list.
    """
    length = norm(vec)
    unit = []
    for component in vec:
        unit.append(component / length)
    return unit

def plus(v1, v2):
    """Return the component-wise sum of ``v1`` and ``v2`` as a new list.

    Components are paired with ``zip``, so the result is as long as the
    shorter of the two inputs.
    """
    return [left + right for left, right in zip(v1, v2)]

def minus(v1, v2):
    """Return the component-wise difference ``v1 - v2`` as a new list.

    Components are paired with ``zip``, so the result is as long as the
    shorter of the two inputs.
    """
    return [left - right for left, right in zip(v1, v2)]

def mult(v, c):
    """Return ``v`` with every component multiplied by the scalar ``c``."""
    return [component * c for component in v]

def SegmentedLine(points):
    """Apply one Koch-curve refinement step to a polyline.

    Every segment ``p1 -> p2`` is replaced by four shorter segments: its
    middle third is swapped for the two other sides of an equilateral
    triangle erected on that third. The triangle's apex lies on the side
    reached by turning the segment direction 90 degrees counter-clockwise
    in the xy-plane.

    Args:
        points: Sliceable sequence of at least two points. Each point is a
            sequence of coordinates with at least an x and a y component.

    Returns:
        A new list containing, for each segment, its start point, the 1/3
        point, the triangle apex and the 2/3 point, followed by the last
        input point. ``n`` input points give ``4 * (n - 1) + 1`` output
        points. The segment start points and the final point are the input
        objects themselves, not copies.

    Raises:
        ValueError: If fewer than two points are supplied.
    """
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if len(points) < 2:
        raise ValueError("SegmentedLine needs at least two points")

    # Height of an equilateral triangle with side |vec| / 3, expressed as a
    # fraction of |vec|: (1/3) * (sqrt(3) / 2).
    apex_offset = math.sqrt(3) / 6

    ret = []
    for p1, p2 in zip(points[:-1], points[1:]):
        vec = minus(p2, p1)
        # ``vec`` turned by 90 degrees in the xy-plane; it has the same length
        # as ``vec``, so no normalisation is needed and a zero-length
        # segment simply produces coincident points instead of dividing by 0.
        perpendicular = [-vec[1], vec[0], 0]

        ret.append(p1)
        ret.append(plus(p1, mult(vec, 1 / 3)))
        ret.append(plus(p1, plus(mult(vec, 1 / 2),
                                 mult(perpendicular, apex_offset))))
        ret.append(plus(p1, mult(vec, 2 / 3)))
    ret.append(points[-1])
    return ret

class snowflake(Scene):
    """Scene that grows a Koch curve on one side and assembles a snowflake.

    A horizontal line is refined four times with ``SegmentedLine``; the
    resulting curve is then copied twice, the copies are rotated by
    +/-120 degrees around the centre of the equilateral triangle built on the
    original line, and the three curves are shrunk and centred. Finally the
    straight chord of each curve is drawn in red and everything fades out.
    """

    def construct(self):
        """Build and play the whole animation."""
        # Endpoints of the initial horizontal side of the snowflake.
        base = [[-7, -1, 0],
                [ 7, -1, 0]]
        points = base

        segment = VMobject(stroke_width=1)
        segment.set_points_as_corners(points)

        self.add(*[Dot(p) for p in points])

        self.play(Create(segment))
        self.wait(1)

        # ``vg`` collects the dots of every refinement step so that they can
        # all be removed with a single animation afterwards.
        vg = VGroup()
        self.add(vg)
        for i in range(4):
            points = SegmentedLine(points)

            # Dots get smaller with every step because the point count grows.
            new_vg = VGroup(*[Dot(p, radius=0.1/(i+1)) for p in points])
            self.play(Create(new_vg))

            self.play(segment.animate.set_points_as_corners(points))

            self.wait(1)

            # Hand the dots over to the persistent group and drop the
            # temporary one from the scene.
            # NOTE(review): confirm that removing ``new_vg`` does not also
            # take its dots (now shared with ``vg``) off screen.
            vg.add(*new_vg)
            self.remove(new_vg)

        self.play(Uncreate(vg))

        segment2 = segment.copy()
        segment3 = segment.copy()
        segs = VGroup(segment, segment2, segment3)

        # Pivot for the rotations: the centroid of the equilateral triangle
        # that has the base line as its top side. It sits below the midpoint
        # of the base at a distance of side / (2 * sqrt(3)) (the inradius).
        # Rotating about exactly this point by +/-120 degrees maps the base
        # onto the two other sides, so the three curves meet at the corners.
        side_length = base[1][0] - base[0][0]
        pivot = [(base[0][0] + base[1][0]) / 2,
                 base[0][1] - side_length / (2 * math.sqrt(3)),
                 0]

        self.play(Rotate(segment2, 2 * PI / 3, axis=[0,0,1], about_point=pivot),
                  Rotate(segment3, -2 * PI / 3, axis=[0,0,1], about_point=pivot))
        self.play(segs.animate.scale(0.5).center())

        # Straight chord from the first to the last point of every curve;
        # together they outline the underlying triangle.
        lines = [
            Line(seg.get_all_points()[0], seg.get_all_points()[-1], color=RED)
            for seg in (segment, segment2, segment3)
        ]

        self.play(Create(VGroup(*lines)))

        self.play(*[FadeOut(mob) for mob in self.mobjects])


if __name__ == "__main__":
    # Running the file directly builds the scene and renders it right away.
    # ``True`` is passed positionally as the first argument of ``render``
    # (presumably the preview flag -- confirm against the installed Manim).
    snowflake().render(True)