Moved from ad-hoc parsing of GTFS data to gtfs-kit
This commit is contained in:
parent
e961c7bc42
commit
54a7e7da06
|
|
@ -1,7 +1,7 @@
|
|||
class ArrivalTime():
|
||||
""" Represents the arrival times of buses at one of the configured stops """
|
||||
|
||||
def __init__(self, stop_id: int, route_id: str, destination: str, due_in_seconds: int) -> None:
|
||||
def __init__(self, stop_id: str, route_id: str, destination: str, due_in_seconds: int) -> None:
|
||||
self.stop_id = stop_id
|
||||
self.route_id = route_id
|
||||
self.destination = destination
|
||||
|
|
|
|||
290
gtfs_client.py
290
gtfs_client.py
|
|
@ -1,173 +1,137 @@
|
|||
import csv
|
||||
from datetime import datetime, time, timedelta
|
||||
import queue
|
||||
import threading
|
||||
from time import mktime
|
||||
from io import TextIOWrapper
|
||||
import json
|
||||
import os
|
||||
import traceback
|
||||
import urllib.request
|
||||
import zipfile
|
||||
from arrival_times import ArrivalTime
|
||||
import datetime
|
||||
import gtfs_kit as gk
|
||||
import pandas as pd
|
||||
import queue
|
||||
import time
|
||||
import threading
|
||||
import traceback
|
||||
|
||||
# Constants and configuration
|
||||
GTFS_BASE_DATA_URL = "https://www.transportforireland.ie/transitData/google_transit_combined.zip"
|
||||
GTFS_R_URL = "https://api.nationaltransport.ie/gtfsr/v1?format=json"
|
||||
API_KEY = 'XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX'
|
||||
BASE_DATA_MINSIZE = 20000000 # The zipped base data should be over 20-ish megabytes
|
||||
|
||||
class GTFSClient:
    """GTFS client that downloads the static base data (zipped CSVs) and
    pre-loads the stops, stop times, trips and route names relevant to a
    configured set of stop names."""

    def __init__(self, wanted_stops: list[str], update_queue: queue.Queue, update_interval_seconds: int = 60):
        """Download the base data if needed and preload the entities.

        Args:
            wanted_stops: stop names (as in stops.txt ``stop_name``) to track.
            update_queue: queue that refreshed arrival data is pushed onto.
            update_interval_seconds: refresh period.

        Raises:
            Exception: if the base data cannot be downloaded or looks truncated.
        """
        self.wanted_stops = wanted_stops
        self._update_queue = update_queue
        self._update_interval_seconds = update_interval_seconds

        # Check that the base data exists, and download it if it doesn't.
        base_data_file_name = GTFS_BASE_DATA_URL.split('/')[-1]
        if not os.path.isfile(base_data_file_name):
            try:
                urllib.request.urlretrieve(GTFS_BASE_DATA_URL, base_data_file_name)
                if not os.path.isfile(base_data_file_name):
                    # FIX: was "%s".format(...) — a %-style placeholder passed to
                    # str.format(), so the filename was never interpolated.
                    raise Exception("The file {} was not downloaded.".format(base_data_file_name))
                if os.path.getsize(base_data_file_name) < BASE_DATA_MINSIZE:
                    raise Exception("The base data file {} was too small.".format(base_data_file_name))
            except Exception as e:
                print("Error downloading base data: {}".format(str(e)))
                raise

        # Preload the entities from the base data.
        with zipfile.ZipFile(base_data_file_name) as zipped_base_data:
            # Load stops and select the stop IDs we are interested in.
            print('Loading stops...')
            stops = self.loadfrom(zipped_base_data, "stops.txt",
                                  lambda s: s['stop_name'] in wanted_stops)
            self.selected_stop_ids = {s['stop_id'] for s in stops}
            self.stops_by_stop_id = {s['stop_id']: s for s in stops}

            # Load the stop times for the selected stops.
            print('Loading stop times...')
            self.stop_time_by_stop_and_trip_id = {}
            self.selected_trip_ids = set()
            stop_times = self.loadfrom(zipped_base_data, "stop_times.txt",
                                       lambda st: st['stop_id'] in self.selected_stop_ids)
            for st in stop_times:
                self.stop_time_by_stop_and_trip_id[(st['stop_id'], st['trip_id'])] = st
                self.selected_trip_ids.add(st['trip_id'])

            # Load the trips that include the selected stops.
            print('Loading trips...')
            self.trip_by_trip_id = {}
            self.selected_route_ids = set()
            trips = self.loadfrom(zipped_base_data, "trips.txt",
                                  lambda t: t['trip_id'] in self.selected_trip_ids)
            for t in trips:
                self.trip_by_trip_id[t['trip_id']] = t
                self.selected_route_ids.add(t['route_id'])

            # Load the names of the routes for the selected trips.
            routes = self.loadfrom(zipped_base_data, 'routes.txt',
                                   lambda r: r['route_id'] in self.selected_route_ids)
            self.route_name_by_route_id = {r['route_id']: r['route_short_name'] for r in routes}
|
||||
class GTFSClient():
    def __init__(self, feed_name: str, stop_names: list[str], update_queue: queue.Queue, update_interval_seconds: int = 60):
        """Load the GTFS feed and resolve the wanted stop IDs.

        Args:
            feed_name: path to the (zipped) GTFS feed readable by gtfs-kit.
            stop_names: stop names to track.
            update_queue: queue refreshed arrivals are pushed onto (may be None).
            update_interval_seconds: refresh period; falsy disables the refresh thread.
        """
        self.stop_names = stop_names
        self.feed = gk.read_feed(feed_name, dist_units='km')
        self.stop_ids = self.__wanted_stop_ids()
        self.update_queue = update_queue
        # FIX: this attribute is read by the refresh lambda below but was
        # never assigned, causing AttributeError once the thread ran.
        self._update_interval_seconds = update_interval_seconds

        # Schedule refresh
        if update_interval_seconds and update_queue:
            self._refresh_thread = threading.Thread(
                target=lambda: every(self._update_interval_seconds, self.refresh),
                daemon=True)  # daemon so the periodic loop doesn't block exit
            # FIX: the thread was created but never started.
            self._refresh_thread.start()
|
||||
|
||||
|
||||
def __wanted_stop_ids(self) -> pd.core.series.Series:
    """
    Return a Series with the IDs of the stops whose names appear in
    self.stop_names.

    FIX: the annotation/docstring previously claimed a DataFrame, but
    ``stops["stop_id"]`` (single-column selection) yields a Series.

    Raises:
        Exception: if none of the configured stop names match the feed.
    """
    stops = self.feed.stops[self.feed.stops["stop_name"].isin(self.stop_names)]
    if stops.empty:
        raise Exception("Stops is empty!")
    return stops["stop_id"]
|
||||
|
||||
|
||||
def __service_ids_active_at(self, when: datetime.datetime) -> pd.core.frame.DataFrame:
    """
    Return the calendar rows (service definitions) active on the date of `when`.

    A service is active when `when` falls inside its [start_date, end_date]
    window and its weekday column (e.g. "monday") equals 1.
    """
    date_str = when.strftime("%Y%m%d")
    weekday = when.strftime("%A").lower()
    # GTFS calendar start_date/end_date are inclusive; FIX: the previous
    # strict </> comparison dropped services on their first and last day.
    active_calendars = self.feed.calendar.query(
        'start_date <= @date_str and end_date >= @date_str and {} == 1'.format(weekday))
    return active_calendars
|
||||
|
||||
|
||||
def __current_service_ids(self) -> pd.core.series.Series:
    """
    Collect the service IDs applicable right now.

    Combines today's active calendar entries with tomorrow's (so that trips
    spilling past midnight are still covered) and returns their service IDs.

    Raises:
        Exception: if either day has no active services, or the combined
        calendar set is unexpectedly empty.
    """
    # Service IDs active today.
    today = datetime.datetime.now()
    today_active = self.__service_ids_active_at(today)
    if today_active.empty:
        raise Exception("There are no service IDs for today!")

    # Merge in tomorrow's service IDs (trips may spill over midnight).
    tomorrow = datetime.datetime.now() + datetime.timedelta(days=1)
    tomorrow_active = self.__service_ids_active_at(tomorrow)
    if tomorrow_active.empty:
        raise Exception("There are no service IDs for tomorrow!")

    combined = pd.concat([today_active, tomorrow_active])
    if combined.empty:
        raise Exception("The concatenation of today and tomorrow's calendars is empty. This should not happen.")

    return combined["service_id"]
|
||||
|
||||
|
||||
def __trip_ids_for_service_ids(self, service_ids: pd.core.series.Series) -> pd.core.series.Series:
    """
    Return the trip IDs of every trip belonging to one of `service_ids`.

    Raises:
        Exception: if no trip in the feed matches the given service IDs.
    """
    in_service = self.feed.trips["service_id"].isin(service_ids)
    matching_trips = self.feed.trips[in_service]
    if matching_trips.empty:
        raise Exception("There are no active trips!")
    return matching_trips["trip_id"]
|
||||
|
||||
|
||||
def __next_n_buses(self,
                   trip_ids: pd.core.series.Series,
                   n: int) -> pd.core.frame.DataFrame:
    """
    Return the next `n` stop-time rows (trip_id, arrival_time, stop_id)
    at the configured stops, restricted to `trip_ids`, with an arrival
    time later than the current wall-clock time.
    """
    now = datetime.datetime.now()
    # FIX: was "%H:%m:%S" — %m is the MONTH; the minutes slot therefore held
    # the month number, corrupting the lexicographic time comparison below.
    current_time = now.strftime("%H:%M:%S")
    stop_times = self.feed.stop_times
    next_stops = stop_times[stop_times["stop_id"].isin(self.stop_ids)
                            & stop_times["trip_id"].isin(trip_ids)
                            & (stop_times["arrival_time"] > current_time)]
    next_stops = next_stops.sort_values("arrival_time")
    return next_stops[:n][["trip_id", "arrival_time", "stop_id"]]
|
||||
|
||||
|
||||
def __join_data(self, next_buses: pd.core.frame.DataFrame) -> pd.core.frame.DataFrame:
    """
    Enrich the stop rows with trip, stop and route details from the feed.
    """
    trips_by_id = self.feed.trips.set_index("trip_id")
    stops_by_id = self.feed.stops.set_index("stop_id")
    routes_by_id = self.feed.routes.set_index("route_id")

    enriched = next_buses.join(trips_by_id, on="trip_id")
    enriched = enriched.join(stops_by_id, on="stop_id")
    enriched = enriched.join(routes_by_id, on="route_id")
    return enriched
|
||||
|
||||
|
||||
def get_next_n_buses(self, num_entries: int) -> pd.core.frame.DataFrame:
    """
    Return a dataframe describing the next `num_entries` buses due at the
    requested stops, enriched with trip, stop and route information.
    """
    active_services = self.__current_service_ids()
    active_trips = self.__trip_ids_for_service_ids(active_services)
    upcoming = self.__next_n_buses(active_trips, num_entries)
    return self.__join_data(upcoming)
|
||||
|
||||
|
||||
|
||||
def refresh(self):
    """
    Create and enqueue the refreshed stop data.

    Builds ArrivalTime entries for the next few buses, pushes the sorted
    list onto self.update_queue (when one is configured) and returns it.
    """
    arrivals = []
    buses = self.get_next_n_buses(5)

    for index, bus in buses.iterrows():
        # Route long names typically look like "Origin - Destination".
        # FIX: split(" - ")[1] raised IndexError when the separator was
        # missing; fall back to the whole long name in that case.
        long_name_parts = bus["route_long_name"].split(" - ")
        if len(long_name_parts) > 1:
            destination = long_name_parts[1].strip()
        else:
            destination = bus["route_long_name"].strip()

        arrival = ArrivalTime(stop_id=bus["stop_id"],
                              route_id=bus["route_short_name"],
                              destination=destination,
                              due_in_seconds=0)
        arrivals.append(arrival)

    arrivals = sorted(arrivals)
    if self.update_queue:
        self.update_queue.put(arrivals)
    return arrivals
|
||||
|
||||
def refresh(self) -> None:
    """ Poll for new and updated information. Queue it for display update. """
    # Retrieve the updated JSON from the GTFS-R endpoint.
    # FIX: urllib.request.URLopener is deprecated (removed in Python 3.11);
    # build a Request with the API-key header and use urlopen instead.
    request = urllib.request.Request(GTFS_R_URL, headers={'x-api-key': API_KEY})
    with urllib.request.urlopen(request) as response:
        gtfs_r_json = response.read()
    # FIX: was calling the module-level `gtfs` object (only defined when run
    # as a script) instead of this instance.
    arrivals = self.update_schedule_from(gtfs_r_json)
    self._update_queue.put(arrivals)
|
||||
|
||||
def calculate_delta(self, stop_time: str, delta: int) -> int:
    """
    Return the unix timestamp (seconds) of today's `stop_time` shifted by
    `delta` seconds.

    Args:
        stop_time: "HH:MM:SS" scheduled arrival. GTFS allows hours >= 24
            for trips running past midnight (e.g. "25:10:00"); these roll
            over into the following day.
        delta: signed offset in seconds (e.g. a GTFS-R arrival delay).

    FIX: the docstring was truncated and the return annotation said
    `datetime`, but an int timestamp is returned; also, time() rejects
    the GTFS-legal hour values >= 24, so those are folded into days.
    """
    hours, minutes, seconds = (int(part) for part in stop_time.split(':'))
    initial = datetime.combine(datetime.now(),
                               time(hours % 24, minutes, seconds))
    adjusted = initial + timedelta(days=hours // 24, seconds=delta)
    return int(mktime(adjusted.timetuple()))
|
||||
|
||||
|
||||
def every(delay, task) -> None:
|
||||
""" Auxilliary function to schedule updates.
|
||||
|
|
@ -186,26 +150,6 @@ def every(delay, task) -> None:
|
|||
next_time += (time.time() - next_time) // delay * delay + delay
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: build a client over the combined feed for two stops and
    # print one refresh cycle (no update queue, no periodic refresh thread).
    # FIX: the old harness called gtfs.update_schedule_from(), a method that
    # no longer exists, via the deprecated urllib.request.URLopener; only the
    # working gtfs-kit harness is kept.
    c = GTFSClient('google_transit_combined.zip',
                   ['College Drive, stop 2410', 'Priory Walk, stop 1114'],
                   None, None)
    print(c.refresh())
|
||||
Loading…
Reference in New Issue