......755.
...$.*....
.664.598..
-
with open(options.filename, 'r') as input:
schematic = Schematic.fromstream(input)
- print(f"part1: {sum(schematic.parts)}")
+ partnos = [p.number for p in schematic.parts]
+ print(f"part1: {sum(partnos)}")
+
+ collection = []
+ for gear in schematic.gears:
+ if len(schematic.gears[gear]) > 1:
+ collection.append([part.number for part in schematic.gears[gear]])
+
+ result = sum([(g[0] * g[1]) for g in collection])
+ print(f"part2: {result}")
+
return 0
+from dataclasses import dataclass
from typing import List
+@dataclass
+class Position:
+ row: int = 0
+ col: int = 0
+
+
+@dataclass
+class Symbol:
+ symbol: str
+ pos: Position
+
+ @property
+ def symid(self):
+ return self.pos.row * 1000 + self.pos.col
+
+
+@dataclass
+class Part:
+ number: int
+ symbols: List[Symbol]
+
+ def __repr__(self) -> str:
+ return f"Part(number='{self.number}',symbols='{self.symbols}')"
+
+ @property
+ def is_valid(self):
+ return bool(self.symbols)
+
+ @property
+ def is_gear(self):
+ return [sym for sym in self.symbols if sym.symbol == '*']
+
+
class Schematic:
def __init__(self, rows: List[str]):
self.rows = rows
self.span = len(rows[0])
self.parts = []
+ self.gears = {}
self.debug = True
self.update()
c = self.rows[row][col]
if c.isdigit():
part['digits'].append(c)
- part['symbols'].append(self.get_neighbours(row, col))
+ part['symbols'] += self.get_neighbours(row, col)
else:
if part['digits']:
self._complete_part(part)
if part['digits']:
self._complete_part(part)
part = {'digits': [], 'symbols':[]}
+ # get gears
+ for part in self.parts:
+ if part.is_gear:
+ symid = part.is_gear[0].symid
+ self.gears.setdefault(symid, [])
+ self.gears[symid].append(part)
+ #print(self.gears)
- def _complete_part(self, part:dict):
- value = int(''.join(part['digits']))
- valid = bool(''.join(part['symbols']))
+ def _complete_part(self, data:dict):
+ value = int(''.join(data['digits']))
+ part = Part(value, data['symbols'])
if self.debug:
- print(value, valid, ''.join(part['symbols']))
- if valid:
- self.parts.append(value)
+ print(part.number, part.is_valid, part.symbols)
+ if part.is_valid:
+ self.parts.append(part)
- def get_neighbours(self, row: int, col: int) -> bool:
- result = ''
+ def get_neighbours(self, row: int, col: int) -> List[Position]:
+ result = []
for r in [row - 1, row, row + 1]:
if r >= 0 and r < len(self.rows):
for c in [col - 1, col, col + 1]:
if c >= 0 and c < self.span:
t = self.rows[r][c]
if not (t.isdigit() or t == '.'):
- result = result + t
+ result.append(Symbol(t, Position(r, c)))
return result
@staticmethod
import unittest
-import os
-import sys
from io import StringIO
from schematic import Schematic
def test_schematic_sum(self):
with open('data/test_input') as input:
schematic = Schematic.fromstream(input)
+ partnos = [p.number for p in schematic.parts]
self.assertSequenceEqual(
[467, 35, 633, 617, 592, 755, 664, 598],
- schematic.parts)
- self.assertEqual(4361, sum(schematic.parts))
+ partnos)
+ self.assertEqual(4361, sum(partnos))
def test_schematic_from_file2(self):
with open('data/test_input2') as input:
schematic = Schematic.fromstream(input)
self.assertSequenceEqual(
[467, 35, 633, 617, 592, 755, 664, 598],
- schematic.parts)
+ [p.number for p in schematic.parts])
+
+
+class TestPart2(unittest.TestCase):
+ def test_is_gear(self):
+ stream = StringIO("111...\n..*...\n..222.\n")
+ schematic = Schematic.fromstream(stream)
+ part = schematic.parts[0]
+ self.assertTrue(part.is_gear)
+
+ def test_gears(self):
+ stream = StringIO("111...\n..*...\n..222.\n")
+ schematic = Schematic.fromstream(stream)
+ print(schematic.gears)
+ self.assertEqual(1, len(schematic.gears))
+ self.assertEqual(2, len(schematic.gears[1002]))
+ self.assertEqual([111, 222], [part.number for part in schematic.gears[1002]])
+
+ def test_gear_ratio(self):
+ with open('data/test_input') as input:
+ schematic = Schematic.fromstream(input)
+ collection = []
+ for gear in schematic.gears:
+ if len(schematic.gears[gear]) > 1:
+ collection.append([part.number for part in schematic.gears[gear]])
+
+ self.assertEqual(
+ [[467, 35], [755, 598]],
+ collection)
+ self.assertEqual(
+ 467835,
+ sum([(g[0] * g[1]) for g in collection]))
+
if __name__ == '__main__':
unittest.main()