diff --git a/arrival_times.py b/arrival_times.py index 114869a..59ab922 100644 --- a/arrival_times.py +++ b/arrival_times.py @@ -1,7 +1,7 @@ class ArrivalTime(): """ Represents the arrival times of buses at one of the configured stops """ - def __init__(self, stop_id: int, route_id: str, destination: str, due_in_seconds: int) -> None: + def __init__(self, stop_id: str, route_id: str, destination: str, due_in_seconds: int) -> None: self.stop_id = stop_id self.route_id = route_id self.destination = destination diff --git a/gtfs_client.py b/gtfs_client.py index 73deca8..d8c8992 100644 --- a/gtfs_client.py +++ b/gtfs_client.py @@ -1,173 +1,137 @@ -import csv -from datetime import datetime, time, timedelta -import queue -import threading -from time import mktime -from io import TextIOWrapper -import json -import os -import traceback -import urllib.request -import zipfile from arrival_times import ArrivalTime +import datetime +import gtfs_kit as gk +import pandas as pd +import queue +import time +import threading +import traceback -# Constants and configuration -GTFS_BASE_DATA_URL = "https://www.transportforireland.ie/transitData/google_transit_combined.zip" -GTFS_R_URL = "https://api.nationaltransport.ie/gtfsr/v1?format=json" -API_KEY = 'XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX' -BASE_DATA_MINSIZE = 20000000 # The zipped base data should be over 20-ish megabytes - -class GTFSClient: - def __init__(self, wanted_stops : list[str], update_queue: queue.Queue, update_interval_seconds: int = 60): - self.wanted_stops = wanted_stops - self._update_queue = update_queue - self._update_interval_seconds = update_interval_seconds - - # Check that the base data exists, and download it if it doesn't - base_data_file_name = GTFS_BASE_DATA_URL.split('/')[-1] - if not os.path.isfile(base_data_file_name): - try: - urllib.request.urlretrieve(GTFS_BASE_DATA_URL, base_data_file_name) - if not os.path.isfile(base_data_file_name): - raise Exception("The file %s was not downloaded.".format(base_data_file_name)) - if os.path.getsize(base_data_file_name) < BASE_DATA_MINSIZE: - raise Exception("The base data file {} was too small.".format(base_data_file_name)) - except Exception as e: - print("Error downloading base data: {}".format(str(e))) - raise e - - # Preload the entities from the base data - with zipfile.ZipFile(base_data_file_name) as zipped_base_data: - # Load stops and select the stop IDs we are interested in - print ('Loading stops...') - stops = self.loadfrom(zipped_base_data, "stops.txt", - lambda s: s['stop_name'] in wanted_stops) - self.selected_stop_ids = set([s['stop_id'] for s in stops]) - self.stops_by_stop_id = {} - for stop in stops: - self.stops_by_stop_id[stop['stop_id']] = stop - - # Load the stop times for the selected stops - print ('Loading stop times...') - self.stop_time_by_stop_and_trip_id = {} - self.selected_trip_ids = set() - stop_times = self.loadfrom(zipped_base_data, "stop_times.txt", - lambda st: st['stop_id'] in self.selected_stop_ids ) - for st in stop_times: - self.stop_time_by_stop_and_trip_id[(st['stop_id'], st['trip_id'])] = st - self.selected_trip_ids.add(st['trip_id']) - - # Load the trips that include the selected stops - print ('Loading trips...') - self.trip_by_trip_id = {} - self.selected_route_ids = set() - trips = self.loadfrom(zipped_base_data, "trips.txt", - lambda t: t['trip_id'] in self.selected_trip_ids) - for t in trips: - self.trip_by_trip_id[t['trip_id']] = t - self.selected_route_ids.add(t['route_id']) - - # Load the names of the routes for the selected trips - routes = self.loadfrom(zipped_base_data, 'routes.txt', - lambda r: r['route_id'] in self.selected_route_ids) - self.route_name_by_route_id = {} - for r in routes: - self.route_name_by_route_id[r['route_id']] = r['route_short_name'] +class GTFSClient(): + def __init__(self, feed_name: str, stop_names: list[str], update_queue: queue.Queue, update_interval_seconds: int = 60): + self.stop_names = stop_names + self.feed = gk.read_feed(feed_name, dist_units='km') + self.stop_ids = self.__wanted_stop_ids() # Schedule refresh - if update_interval_seconds: + self.update_queue = update_queue + if update_interval_seconds and update_queue: self._refresh_thread = threading.Thread(target=lambda: every(self._update_interval_seconds, self.refresh)) - def loadfrom(self, zipfile: zipfile.ZipFile, name: str, filter: callable = None) -> map: + def __wanted_stop_ids(self) -> pd.core.frame.DataFrame: """ - Load a CSV file from the zip + Return a DataFrame with the ID and names of the chosen stop(s) as requested in station_names """ - with zipfile.open(name, "r") as datafile: - if not datafile: - raise Exception('File %s is not in the zipped data'.format(name)) - if filter: - result = [] - for r in csv.DictReader(TextIOWrapper(datafile, "utf-8-sig")): - if filter(r): - result.append(r) - else: - result = [r for r in csv.DictReader(TextIOWrapper(datafile, "utf-8-sig"))] - return result + stops = self.feed.stops[self.feed.stops["stop_name"].isin(self.stop_names)] + if stops.empty: + raise Exception("Stops is empty!") + return stops["stop_id"] - def update_schedule_from(self, gtfsr_json: str) -> list: + def __service_ids_active_at(self, when: datetime) -> pd.core.frame.DataFrame: """ - Creates a structure with the routes and arrival times from the - preloaded information, plus the gtfsr data received from the API + Returns the service IDs active at a particular point in time """ - # Parse JSON - gtfsr_data = json.loads(gtfsr_json) - entities = gtfsr_data['Entity'] + todays_date = when.strftime("%Y%m%d") + todays_weekday = when.strftime("%A").lower() + active_calendars = self.feed.calendar.query('start_date < @todays_date and end_date > @todays_date and {} == 1'.format(todays_weekday)) + return active_calendars + + def __current_service_ids(self) -> pd.core.series.Series: + """ + Filter the calendar entries to find all service ids that apply for today. + Returns an empty list if none do. + """ + + # Take the service IDs active today + now = datetime.datetime.now() + now_active = self.__service_ids_active_at(now) + if now_active.empty: + raise Exception("There are no service IDs for today!") + + # Merge with the service IDs for tomorrow (in case the number of trips spills over to tomorrow) + tomorrow = datetime.datetime.now() + datetime.timedelta(days=1) + tomorrow_active = self.__service_ids_active_at(tomorrow) + if tomorrow_active.empty: + raise Exception("There are no service IDs for tomorrow!") + + active_calendars = pd.concat([now_active, tomorrow_active]) + if active_calendars.empty: + raise Exception("The concatenation of today and tomorrow's calendars is empty. This should not happen.") + + return active_calendars["service_id"] + + + def __trip_ids_for_service_ids(self, service_ids: pd.core.series.Series) -> pd.core.series.Series: + """ + Returns a dataframe with the trip IDs for the given service IDs + """ + trips = self.feed.trips[self.feed.trips["service_id"].isin(service_ids)] + if trips.empty: + raise Exception("There are no active trips!") + + return trips["trip_id"] + + + def __next_n_buses(self, + trip_ids: pd.core.series.Series, + n: int) -> pd.core.frame.DataFrame: + now = datetime.datetime.now() + current_time = now.strftime("%H:%m:%S") + next_stops = self.feed.stop_times[self.feed.stop_times["stop_id"].isin(self.stop_ids) + & self.feed.stop_times["trip_id"].isin(trip_ids) + & (self.feed.stop_times["arrival_time"] > current_time)] + next_stops = next_stops.sort_values("arrival_time") + return next_stops[:n][["trip_id", "arrival_time", "stop_id"]] + + + def __join_data(self, next_buses: pd.core.frame.DataFrame) -> pd.core.frame.DataFrame: + """ + Enriches the stop data with the information from other dataframes in the feed + """ + joined_data = (next_buses + .join(self.feed.trips.set_index("trip_id"), on="trip_id") + .join(self.feed.stops.set_index("stop_id"), on="stop_id") + .join(self.feed.routes.set_index("route_id"), on="route_id")) + + return joined_data + + + def get_next_n_buses(self, num_entries: int) -> pd.core.frame.DataFrame: + """ + Returns a dataframe with the information of the next N buses arriving at the requested stops. + """ + service_ids = self.__current_service_ids() + trip_ids = self.__trip_ids_for_service_ids(service_ids) + next_buses = self.__next_n_buses(trip_ids, num_entries) + joined_data = self.__join_data(next_buses) + return joined_data + + + + def refresh(self): + """ + Create and enqueue the refreshed stop data + """ + arrivals = [] - for e in entities: - # Skip non-updates and invalid entries - if (e.get('IsDeleted') - or not e.get('TripUpdate') - or not e['TripUpdate'].get('Trip') - or not e['TripUpdate']['Trip'].get('TripId') in self.selected_trip_ids): - continue + buses = self.get_next_n_buses(5) + + for index, bus in buses.iterrows(): + arrival = ArrivalTime(stop_id = bus["stop_id"], + route_id = bus["route_short_name"], + destination= bus["route_long_name"].split(" - ")[1].strip(), + due_in_seconds = 0 + ) + arrivals.append(arrival) - # e contains an update for a trip we are interested in. - stop_times = e['TripUpdate'].get('StopTimeUpdate') - if not stop_times: - print('A TripUpdate entry does not have StopTimeUpdate:') - print(e) - continue - - for st in stop_times: - # Skip the stops we are not interested in - if not st.get('StopId') in self.selected_stop_ids: - continue - - # We have a stop time for one of our stops. Collect all info - trip_id = e['TripUpdate']['Trip']['TripId'] - trip = self.trip_by_trip_id[trip_id] - trip_destination = trip['trip_headsign'] - if len(trip_destination.split(' - ')) > 1: - trip_destination = trip_destination.split(' - ')[1] - route_id = self.trip_by_trip_id[trip_id]['route_id'] - route_name = self.route_name_by_route_id[route_id] - stop_name = self.stops_by_stop_id[st['StopId']]['stop_name'] - stop_time = self.calculate_delta( - self.stop_time_by_stop_and_trip_id[(st['StopId'], trip_id)]['arrival_time'], - st['Arrival'].get('Delay') or 0 - ) - current_timestamp = (mktime(datetime.now().timetuple())) - due_in_seconds = stop_time - current_timestamp - - arrival_time = ArrivalTime(stop_name, route_name, trip_destination, due_in_seconds) - arrivals.append(arrival_time) - arrivals = sorted(arrivals) + if self.update_queue: + self.update_queue.put(arrivals) return arrivals - def refresh(self) -> None: - """ Poll for new and updated information. Queue it for display update. """ - # Retrieve the updated json - url_opener = urllib.request.URLopener() - url_opener.addheader('x-api-key', API_KEY) - response = url_opener.open(GTFS_R_URL) - gtfs_r_json = response.file.read() - arrivals = gtfs.update_schedule_from(gtfs_r_json) - self._update_queue.put(arrivals) - - def calculate_delta(self, stop_time: str, delta: int) -> datetime: - """ - Returns a unix timestamp of - """ - stop_time_parts = list(map(lambda n: int(n), stop_time.split(':'))) - initial = datetime.combine(datetime.now(), - time(stop_time_parts[0], stop_time_parts[1], stop_time_parts[2])) - adjusted = initial + timedelta(seconds = delta) - return int(mktime(adjusted.timetuple())) - def every(delay, task) -> None: """ Auxilliary function to schedule updates. @@ -186,26 +150,6 @@ def every(delay, task) -> None: next_time += (time.time() - next_time) // delay * delay + delay -if __name__ == "__main__": - gtfs = GTFSClient([ - "Priory Walk, stop 1114", - "College Drive, stop 2410", - "Kimmage Road Lower, stop 2438", - "Brookfield, stop 2437" - ], queue.Queue(), None) - - if True: - o = urllib.request.URLopener() - o.addheader('x-api-key', API_KEY) - r = o.open(GTFS_R_URL) - if r.code != 200: - print(r.file.read()) - exit(1) - gtfs_r_json = r.file.read() - else: - gtfs_r_json = open('example.json').read() - - arrivals = gtfs.update_schedule_from(gtfs_r_json) - print(gtfs) - +c = GTFSClient('google_transit_combined.zip', ['College Drive, stop 2410', 'Priory Walk, stop 1114'], None, None) +print(c.refresh()) \ No newline at end of file