# From b04648b1e238f64bd89e7415cb958f4d91f7a686 Mon Sep 17 00:00:00 2001
# From: Nahuel Lofeudo
# Date: Tue, 7 Mar 2023 19:38:03 +0000
# Subject: [PATCH] Add initial code to migrate from the SOAP API client to GTFS-R
"""Client for Transport for Ireland's GTFS / GTFS-Realtime feeds.

Loads the static GTFS base data (stops, stop_times, trips, routes) from the
published zip file, then merges GTFS-Realtime trip updates fetched from the
NTA API to produce upcoming ArrivalTime records for a set of wanted stops.
"""

import csv
import json
import os
import queue
import threading
import time
import traceback
import urllib.request
import zipfile
from datetime import datetime, time as daytime, timedelta
from io import TextIOWrapper

from arrival_times import ArrivalTime

# Constants and configuration
GTFS_BASE_DATA_URL = "https://www.transportforireland.ie/transitData/google_transit_combined.zip"
GTFS_R_URL = "https://api.nationaltransport.ie/gtfsr/v1?format=json"
# NOTE(security): hard-coded API key committed to source control — move this to
# an environment variable or a config file before release.
API_KEY = '470fcdd00bfe45c188fb236757d2df4f'
BASE_DATA_MINSIZE = 20000000  # The zipped base data should be over 20-ish megabytes


class GTFSClient:
    """Resolves GTFS-Realtime trip updates into arrival times for chosen stops."""

    def __init__(self, wanted_stops: list[str], update_queue: queue.Queue,
                 update_interval_seconds: int = 60):
        """Download/load the GTFS base data and optionally start periodic refresh.

        wanted_stops: stop names (as they appear in stops.txt) to watch.
        update_queue: queue that receives the sorted list of ArrivalTime
            objects on every refresh.
        update_interval_seconds: polling period; falsy disables the
            background refresh thread.
        """
        self.wanted_stops = wanted_stops
        self._update_queue = update_queue
        self._update_interval_seconds = update_interval_seconds

        # Check that the base data exists, and download it if it doesn't
        base_data_file_name = GTFS_BASE_DATA_URL.split('/')[-1]
        if not os.path.isfile(base_data_file_name):
            try:
                urllib.request.urlretrieve(GTFS_BASE_DATA_URL, base_data_file_name)
                if not os.path.isfile(base_data_file_name):
                    # BUG FIX: original mixed a %-style placeholder with str.format()
                    raise Exception("The file {} was not downloaded.".format(base_data_file_name))
                if os.path.getsize(base_data_file_name) < BASE_DATA_MINSIZE:
                    raise Exception("The base data file {} was too small.".format(base_data_file_name))
            except Exception as e:
                print("Error downloading base data: {}".format(str(e)))
                raise

        # Preload the entities from the base data
        with zipfile.ZipFile(base_data_file_name) as zipped_base_data:
            # Load stops and select the stop IDs we are interested in
            print('Loading stops...')
            stops = self.loadfrom(zipped_base_data, "stops.txt",
                                  lambda s: s['stop_name'] in wanted_stops)
            self.selected_stop_ids = set(s['stop_id'] for s in stops)
            self.stops_by_stop_id = {stop['stop_id']: stop for stop in stops}

            # Load the stop times for the selected stops
            print('Loading stop times...')
            self.stop_time_by_stop_and_trip_id = {}
            self.selected_trip_ids = set()
            stop_times = self.loadfrom(zipped_base_data, "stop_times.txt",
                                       lambda st: st['stop_id'] in self.selected_stop_ids)
            for st in stop_times:
                self.stop_time_by_stop_and_trip_id[(st['stop_id'], st['trip_id'])] = st
                self.selected_trip_ids.add(st['trip_id'])

            # Load the trips that include the selected stops
            print('Loading trips...')
            self.trip_by_trip_id = {}
            self.selected_route_ids = set()
            trips = self.loadfrom(zipped_base_data, "trips.txt",
                                  lambda t: t['trip_id'] in self.selected_trip_ids)
            for t in trips:
                self.trip_by_trip_id[t['trip_id']] = t
                self.selected_route_ids.add(t['route_id'])

            # Load the names of the routes for the selected trips
            routes = self.loadfrom(zipped_base_data, 'routes.txt',
                                   lambda r: r['route_id'] in self.selected_route_ids)
            self.route_name_by_route_id = {r['route_id']: r['route_short_name']
                                           for r in routes}

        # Schedule refresh. BUG FIX: the original created the thread but never
        # called start(); it is daemonized so it cannot block interpreter exit.
        if update_interval_seconds:
            self._refresh_thread = threading.Thread(
                target=lambda: every(self._update_interval_seconds, self.refresh),
                daemon=True)
            self._refresh_thread.start()

    def loadfrom(self, zipped_data: zipfile.ZipFile, name: str,
                 predicate: callable = None) -> list:
        """Load one CSV member of the zipped GTFS base data.

        zipped_data: the open base-data zip archive.
        name: member file name, e.g. "stops.txt".
        predicate: optional row filter; rows for which it returns falsy are
            dropped. (Renamed from `filter` to avoid shadowing the builtin.)
        Returns a list of dict rows. Raises KeyError if `name` is not in the
        archive (ZipFile.open raises; the original's `if not datafile` check
        was dead code).
        """
        with zipped_data.open(name, "r") as datafile:
            # utf-8-sig strips the BOM that the TFI export carries
            reader = csv.DictReader(TextIOWrapper(datafile, "utf-8-sig"))
            if predicate:
                return [row for row in reader if predicate(row)]
            return list(reader)

    def update_schedule_from(self, gtfsr_json: str) -> list:
        """
        Creates a structure with the routes and arrival times from the
        preloaded information, plus the gtfsr data received from the API.

        Returns a sorted list of ArrivalTime objects for the wanted stops.
        """
        # Parse JSON (field names follow the NTA JSON rendering of GTFS-R)
        gtfsr_data = json.loads(gtfsr_json)
        entities = gtfsr_data['Entity']

        arrivals = []

        for e in entities:
            # Skip non-updates and invalid entries
            if (e.get('IsDeleted')
                    or not e.get('TripUpdate')
                    or not e['TripUpdate'].get('Trip')
                    or not e['TripUpdate']['Trip'].get('TripId') in self.selected_trip_ids):
                continue

            # e contains an update for a trip we are interested in.
            stop_times = e['TripUpdate'].get('StopTimeUpdate')
            if not stop_times:
                print('A TripUpdate entry does not have StopTimeUpdate:')
                print(e)
                continue

            for st in stop_times:
                # Skip the stops we are not interested in
                if not st.get('StopId') in self.selected_stop_ids:
                    continue

                # We have a stop time for one of our stops. Collect all info
                trip_id = e['TripUpdate']['Trip']['TripId']
                trip = self.trip_by_trip_id[trip_id]
                trip_destination = trip['trip_headsign']
                # Headsigns look like "Origin - Destination"; keep the latter
                if len(trip_destination.split(' - ')) > 1:
                    trip_destination = trip_destination.split(' - ')[1]
                route_id = trip['route_id']
                route_name = self.route_name_by_route_id[route_id]
                stop_name = self.stops_by_stop_id[st['StopId']]['stop_name']
                # BUG FIX: a StopTimeUpdate may omit 'Arrival' entirely; the
                # original st['Arrival'] raised KeyError in that case.
                delay_seconds = (st.get('Arrival') or {}).get('Delay') or 0
                stop_time = self.calculate_delta(
                    self.stop_time_by_stop_and_trip_id[(st['StopId'], trip_id)]['arrival_time'],
                    delay_seconds
                )
                current_timestamp = time.mktime(datetime.now().timetuple())
                due_in_seconds = stop_time - current_timestamp

                arrival_time = ArrivalTime(stop_name, route_name, trip_destination, due_in_seconds)
                arrivals.append(arrival_time)
        return sorted(arrivals)

    def refresh(self) -> None:
        """ Poll for new and updated information. Queue it for display update. """
        # Retrieve the updated json. BUG FIX: the original used the deprecated
        # urllib.request.URLopener and read the private `response.file`.
        request = urllib.request.Request(GTFS_R_URL, headers={'x-api-key': API_KEY})
        with urllib.request.urlopen(request) as response:
            gtfs_r_json = response.read()
        # BUG FIX: original called the module-global `gtfs` instead of `self`
        arrivals = self.update_schedule_from(gtfs_r_json)
        self._update_queue.put(arrivals)

    def calculate_delta(self, stop_time: str, delta: int) -> int:
        """
        Return the unix timestamp (local time) of today's scheduled arrival
        `stop_time` ("HH:MM:SS") shifted by `delta` seconds of realtime delay.

        NOTE(review): GTFS allows hour values >= 24 for after-midnight trips,
        which datetime.time rejects — confirm how this feed encodes those.
        """
        hours, minutes, seconds = (int(part) for part in stop_time.split(':'))
        initial = datetime.combine(datetime.now(), daytime(hours, minutes, seconds))
        adjusted = initial + timedelta(seconds=delta)
        return int(time.mktime(adjusted.timetuple()))


def every(delay, task) -> None:
    """ Auxiliary function to schedule updates: run `task` every `delay` seconds.
    Taken from https://stackoverflow.com/questions/474528/what-is-the-best-way-to-repeatedly-execute-a-function-every-x-seconds

    BUG FIX: the original called time(), time.sleep() and time.time() while
    `time` was the datetime.time CLASS (only mktime had been imported from the
    time module); the `time` module is now imported and used consistently.
    """
    next_time = time.time() + delay
    while True:
        time.sleep(max(0, next_time - time.time()))
        try:
            task()
        except Exception:
            traceback.print_exc()
            # in production code you might want to have this instead of course:
            # logger.exception("Problem while executing repetitive task.")
        # skip tasks if we are behind schedule:
        next_time += (time.time() - next_time) // delay * delay + delay


if __name__ == "__main__":
    gtfs = GTFSClient([
        "Priory Walk, stop 1114",
        "College Drive, stop 2410",
        "Kimmage Road Lower, stop 2438",
        "Brookfield, stop 2437"
    ], queue.Queue(), None)

    # Fetch one live snapshot from the API. (For offline testing, substitute
    # gtfs_r_json = open('example.json').read() as the original draft did.)
    request = urllib.request.Request(GTFS_R_URL, headers={'x-api-key': API_KEY})
    with urllib.request.urlopen(request) as r:
        # urlopen raises HTTPError for non-2xx, but keep the explicit check
        # for parity with the original flow.
        if r.status != 200:
            print(r.read())
            exit(1)
        gtfs_r_json = r.read()

    arrivals = gtfs.update_schedule_from(gtfs_r_json)
    # BUG FIX: the original printed the client object, not the results
    print(arrivals)