Moved from ad-hoc parsing of GTFS data to gtfs-kit

This commit is contained in:
Nahuel Lofeudo 2023-03-25 16:41:19 -03:00
parent e961c7bc42
commit 54a7e7da06
2 changed files with 118 additions and 174 deletions

View File

@ -1,7 +1,7 @@
class ArrivalTime():
""" Represents the arrival times of buses at one of the configured stops """
def __init__(self, stop_id: int, route_id: str, destination: str, due_in_seconds: int) -> None:
def __init__(self, stop_id: str, route_id: str, destination: str, due_in_seconds: int) -> None:
self.stop_id = stop_id
self.route_id = route_id
self.destination = destination

View File

@ -1,173 +1,137 @@
import csv
from datetime import datetime, time, timedelta
import queue
import threading
from time import mktime
from io import TextIOWrapper
import json
import os
import traceback
import urllib.request
import zipfile
from arrival_times import ArrivalTime
import datetime
import gtfs_kit as gk
import pandas as pd
import queue
import time
import threading
import traceback
# Constants and configuration
GTFS_BASE_DATA_URL = "https://www.transportforireland.ie/transitData/google_transit_combined.zip"
GTFS_R_URL = "https://api.nationaltransport.ie/gtfsr/v1?format=json"
API_KEY = 'XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX'
BASE_DATA_MINSIZE = 20000000 # The zipped base data should be over 20-ish megabytes
class GTFSClient:
def __init__(self, wanted_stops : list[str], update_queue: queue.Queue, update_interval_seconds: int = 60):
self.wanted_stops = wanted_stops
self._update_queue = update_queue
self._update_interval_seconds = update_interval_seconds
# Check that the base data exists, and download it if it doesn't
base_data_file_name = GTFS_BASE_DATA_URL.split('/')[-1]
if not os.path.isfile(base_data_file_name):
try:
urllib.request.urlretrieve(GTFS_BASE_DATA_URL, base_data_file_name)
if not os.path.isfile(base_data_file_name):
raise Exception("The file %s was not downloaded.".format(base_data_file_name))
if os.path.getsize(base_data_file_name) < BASE_DATA_MINSIZE:
raise Exception("The base data file {} was too small.".format(base_data_file_name))
except Exception as e:
print("Error downloading base data: {}".format(str(e)))
raise e
# Preload the entities from the base data
with zipfile.ZipFile(base_data_file_name) as zipped_base_data:
# Load stops and select the stop IDs we are interested in
print ('Loading stops...')
stops = self.loadfrom(zipped_base_data, "stops.txt",
lambda s: s['stop_name'] in wanted_stops)
self.selected_stop_ids = set([s['stop_id'] for s in stops])
self.stops_by_stop_id = {}
for stop in stops:
self.stops_by_stop_id[stop['stop_id']] = stop
# Load the stop times for the selected stops
print ('Loading stop times...')
self.stop_time_by_stop_and_trip_id = {}
self.selected_trip_ids = set()
stop_times = self.loadfrom(zipped_base_data, "stop_times.txt",
lambda st: st['stop_id'] in self.selected_stop_ids )
for st in stop_times:
self.stop_time_by_stop_and_trip_id[(st['stop_id'], st['trip_id'])] = st
self.selected_trip_ids.add(st['trip_id'])
# Load the trips that include the selected stops
print ('Loading trips...')
self.trip_by_trip_id = {}
self.selected_route_ids = set()
trips = self.loadfrom(zipped_base_data, "trips.txt",
lambda t: t['trip_id'] in self.selected_trip_ids)
for t in trips:
self.trip_by_trip_id[t['trip_id']] = t
self.selected_route_ids.add(t['route_id'])
# Load the names of the routes for the selected trips
routes = self.loadfrom(zipped_base_data, 'routes.txt',
lambda r: r['route_id'] in self.selected_route_ids)
self.route_name_by_route_id = {}
for r in routes:
self.route_name_by_route_id[r['route_id']] = r['route_short_name']
class GTFSClient():
def __init__(self, feed_name: str, stop_names: list[str], update_queue: queue.Queue, update_interval_seconds: int = 60):
self.stop_names = stop_names
self.feed = gk.read_feed(feed_name, dist_units='km')
self.stop_ids = self.__wanted_stop_ids()
# Schedule refresh
if update_interval_seconds:
self.update_queue = update_queue
if update_interval_seconds and update_queue:
self._refresh_thread = threading.Thread(target=lambda: every(self._update_interval_seconds, self.refresh))
def loadfrom(self, zipfile: zipfile.ZipFile, name: str, filter: callable = None) -> map:
def __wanted_stop_ids(self) -> pd.core.frame.DataFrame:
"""
Load a CSV file from the zip
Return a DataFrame with the ID and names of the chosen stop(s) as requested in station_names
"""
with zipfile.open(name, "r") as datafile:
if not datafile:
raise Exception('File %s is not in the zipped data'.format(name))
if filter:
result = []
for r in csv.DictReader(TextIOWrapper(datafile, "utf-8-sig")):
if filter(r):
result.append(r)
else:
result = [r for r in csv.DictReader(TextIOWrapper(datafile, "utf-8-sig"))]
return result
stops = self.feed.stops[self.feed.stops["stop_name"].isin(self.stop_names)]
if stops.empty:
raise Exception("Stops is empty!")
return stops["stop_id"]
def update_schedule_from(self, gtfsr_json: str) -> list:
def __service_ids_active_at(self, when: datetime) -> pd.core.frame.DataFrame:
"""
Creates a structure with the routes and arrival times from the
preloaded information, plus the gtfsr data received from the API
Returns the service IDs active at a particular point in time
"""
# Parse JSON
gtfsr_data = json.loads(gtfsr_json)
entities = gtfsr_data['Entity']
todays_date = when.strftime("%Y%m%d")
todays_weekday = when.strftime("%A").lower()
active_calendars = self.feed.calendar.query('start_date < @todays_date and end_date > @todays_date and {} == 1'.format(todays_weekday))
return active_calendars
def __current_service_ids(self) -> pd.core.series.Series:
"""
Filter the calendar entries to find all service ids that apply for today.
Returns an empty list if none do.
"""
# Take the service IDs active today
now = datetime.datetime.now()
now_active = self.__service_ids_active_at(now)
if now_active.empty:
raise Exception("There are no service IDs for today!")
# Merge with the service IDs for tomorrow (in case the number of trips spills over to tomorrow)
tomorrow = datetime.datetime.now() + datetime.timedelta(days=1)
tomorrow_active = self.__service_ids_active_at(tomorrow)
if tomorrow_active.empty:
raise Exception("There are no service IDs for tomorrow!")
active_calendars = pd.concat([now_active, tomorrow_active])
if active_calendars.empty:
raise Exception("The concatenation of today and tomorrow's calendars is empty. This should not happen.")
return active_calendars["service_id"]
def __trip_ids_for_service_ids(self, service_ids: pd.core.series.Series) -> pd.core.series.Series:
"""
Returns a dataframe with the trip IDs for the given service IDs
"""
trips = self.feed.trips[self.feed.trips["service_id"].isin(service_ids)]
if trips.empty:
raise Exception("There are no active trips!")
return trips["trip_id"]
def __next_n_buses(self,
trip_ids: pd.core.series.Series,
n: int) -> pd.core.frame.DataFrame:
now = datetime.datetime.now()
current_time = now.strftime("%H:%m:%S")
next_stops = self.feed.stop_times[self.feed.stop_times["stop_id"].isin(self.stop_ids)
& self.feed.stop_times["trip_id"].isin(trip_ids)
& (self.feed.stop_times["arrival_time"] > current_time)]
next_stops = next_stops.sort_values("arrival_time")
return next_stops[:n][["trip_id", "arrival_time", "stop_id"]]
def __join_data(self, next_buses: pd.core.frame.DataFrame) -> pd.core.frame.DataFrame:
"""
Enriches the stop data with the information from other dataframes in the feed
"""
joined_data = (next_buses
.join(self.feed.trips.set_index("trip_id"), on="trip_id")
.join(self.feed.stops.set_index("stop_id"), on="stop_id")
.join(self.feed.routes.set_index("route_id"), on="route_id"))
return joined_data
def get_next_n_buses(self, num_entries: int) -> pd.core.frame.DataFrame:
"""
Returns a dataframe with the information of the next N buses arriving at the requested stops.
"""
service_ids = self.__current_service_ids()
trip_ids = self.__trip_ids_for_service_ids(service_ids)
next_buses = self.__next_n_buses(trip_ids, num_entries)
joined_data = self.__join_data(next_buses)
return joined_data
def refresh(self):
"""
Create and enqueue the refreshed stop data
"""
arrivals = []
for e in entities:
# Skip non-updates and invalid entries
if (e.get('IsDeleted')
or not e.get('TripUpdate')
or not e['TripUpdate'].get('Trip')
or not e['TripUpdate']['Trip'].get('TripId') in self.selected_trip_ids):
continue
buses = self.get_next_n_buses(5)
for index, bus in buses.iterrows():
arrival = ArrivalTime(stop_id = bus["stop_id"],
route_id = bus["route_short_name"],
destination= bus["route_long_name"].split(" - ")[1].strip(),
due_in_seconds = 0
)
arrivals.append(arrival)
# e contains an update for a trip we are interested in.
stop_times = e['TripUpdate'].get('StopTimeUpdate')
if not stop_times:
print('A TripUpdate entry does not have StopTimeUpdate:')
print(e)
continue
for st in stop_times:
# Skip the stops we are not interested in
if not st.get('StopId') in self.selected_stop_ids:
continue
# We have a stop time for one of our stops. Collect all info
trip_id = e['TripUpdate']['Trip']['TripId']
trip = self.trip_by_trip_id[trip_id]
trip_destination = trip['trip_headsign']
if len(trip_destination.split(' - ')) > 1:
trip_destination = trip_destination.split(' - ')[1]
route_id = self.trip_by_trip_id[trip_id]['route_id']
route_name = self.route_name_by_route_id[route_id]
stop_name = self.stops_by_stop_id[st['StopId']]['stop_name']
stop_time = self.calculate_delta(
self.stop_time_by_stop_and_trip_id[(st['StopId'], trip_id)]['arrival_time'],
st['Arrival'].get('Delay') or 0
)
current_timestamp = (mktime(datetime.now().timetuple()))
due_in_seconds = stop_time - current_timestamp
arrival_time = ArrivalTime(stop_name, route_name, trip_destination, due_in_seconds)
arrivals.append(arrival_time)
arrivals = sorted(arrivals)
if self.update_queue:
self.update_queue.put(arrivals)
return arrivals
def refresh(self) -> None:
""" Poll for new and updated information. Queue it for display update. """
# Retrieve the updated json
url_opener = urllib.request.URLopener()
url_opener.addheader('x-api-key', API_KEY)
response = url_opener.open(GTFS_R_URL)
gtfs_r_json = response.file.read()
arrivals = gtfs.update_schedule_from(gtfs_r_json)
self._update_queue.put(arrivals)
def calculate_delta(self, stop_time: str, delta: int) -> datetime:
"""
Returns a unix timestamp of
"""
stop_time_parts = list(map(lambda n: int(n), stop_time.split(':')))
initial = datetime.combine(datetime.now(),
time(stop_time_parts[0], stop_time_parts[1], stop_time_parts[2]))
adjusted = initial + timedelta(seconds = delta)
return int(mktime(adjusted.timetuple()))
def every(delay, task) -> None:
""" Auxilliary function to schedule updates.
@ -186,26 +150,6 @@ def every(delay, task) -> None:
next_time += (time.time() - next_time) // delay * delay + delay
if __name__ == "__main__":
gtfs = GTFSClient([
"Priory Walk, stop 1114",
"College Drive, stop 2410",
"Kimmage Road Lower, stop 2438",
"Brookfield, stop 2437"
], queue.Queue(), None)
if True:
o = urllib.request.URLopener()
o.addheader('x-api-key', API_KEY)
r = o.open(GTFS_R_URL)
if r.code != 200:
print(r.file.read())
exit(1)
gtfs_r_json = r.file.read()
else:
gtfs_r_json = open('example.json').read()
arrivals = gtfs.update_schedule_from(gtfs_r_json)
print(gtfs)
c = GTFSClient('google_transit_combined.zip', ['College Drive, stop 2410', 'Priory Walk, stop 1114'], None, None)
print(c.refresh())