Add initial code to migrate from the SOAP API client to GTFS-R
This commit is contained in:
parent
ca57d1c5ec
commit
b04648b1e2
|
|
@ -0,0 +1,211 @@
|
|||
import csv
|
||||
from datetime import datetime, time, timedelta
|
||||
import queue
|
||||
import threading
|
||||
from time import mktime
|
||||
from io import TextIOWrapper
|
||||
import json
|
||||
import os
|
||||
import traceback
|
||||
import urllib.request
|
||||
import zipfile
|
||||
from arrival_times import ArrivalTime
|
||||
|
||||
# Constants and configuration
|
||||
GTFS_BASE_DATA_URL = "https://www.transportforireland.ie/transitData/google_transit_combined.zip"
|
||||
GTFS_R_URL = "https://api.nationaltransport.ie/gtfsr/v1?format=json"
|
||||
API_KEY = '470fcdd00bfe45c188fb236757d2df4f'
|
||||
BASE_DATA_MINSIZE = 20000000 # The zipped base data should be over 20-ish megabytes
|
||||
|
||||
class GTFSClient:
|
||||
def __init__(self, wanted_stops : list[str], update_queue: queue.Queue, update_interval_seconds: int = 60):
|
||||
self.wanted_stops = wanted_stops
|
||||
self._update_queue = update_queue
|
||||
self._update_interval_seconds = update_interval_seconds
|
||||
|
||||
# Check that the base data exists, and download it if it doesn't
|
||||
base_data_file_name = GTFS_BASE_DATA_URL.split('/')[-1]
|
||||
if not os.path.isfile(base_data_file_name):
|
||||
try:
|
||||
urllib.request.urlretrieve(GTFS_BASE_DATA_URL, base_data_file_name)
|
||||
if not os.path.isfile(base_data_file_name):
|
||||
raise Exception("The file %s was not downloaded.".format(base_data_file_name))
|
||||
if os.path.getsize(base_data_file_name) < BASE_DATA_MINSIZE:
|
||||
raise Exception("The base data file {} was too small.".format(base_data_file_name))
|
||||
except Exception as e:
|
||||
print("Error downloading base data: {}".format(str(e)))
|
||||
raise e
|
||||
|
||||
# Preload the entities from the base data
|
||||
with zipfile.ZipFile(base_data_file_name) as zipped_base_data:
|
||||
# Load stops and select the stop IDs we are interested in
|
||||
print ('Loading stops...')
|
||||
stops = self.loadfrom(zipped_base_data, "stops.txt",
|
||||
lambda s: s['stop_name'] in wanted_stops)
|
||||
self.selected_stop_ids = set([s['stop_id'] for s in stops])
|
||||
self.stops_by_stop_id = {}
|
||||
for stop in stops:
|
||||
self.stops_by_stop_id[stop['stop_id']] = stop
|
||||
|
||||
# Load the stop times for the selected stops
|
||||
print ('Loading stop times...')
|
||||
self.stop_time_by_stop_and_trip_id = {}
|
||||
self.selected_trip_ids = set()
|
||||
stop_times = self.loadfrom(zipped_base_data, "stop_times.txt",
|
||||
lambda st: st['stop_id'] in self.selected_stop_ids )
|
||||
for st in stop_times:
|
||||
self.stop_time_by_stop_and_trip_id[(st['stop_id'], st['trip_id'])] = st
|
||||
self.selected_trip_ids.add(st['trip_id'])
|
||||
|
||||
# Load the trips that include the selected stops
|
||||
print ('Loading trips...')
|
||||
self.trip_by_trip_id = {}
|
||||
self.selected_route_ids = set()
|
||||
trips = self.loadfrom(zipped_base_data, "trips.txt",
|
||||
lambda t: t['trip_id'] in self.selected_trip_ids)
|
||||
for t in trips:
|
||||
self.trip_by_trip_id[t['trip_id']] = t
|
||||
self.selected_route_ids.add(t['route_id'])
|
||||
|
||||
# Load the names of the routes for the selected trips
|
||||
routes = self.loadfrom(zipped_base_data, 'routes.txt',
|
||||
lambda r: r['route_id'] in self.selected_route_ids)
|
||||
self.route_name_by_route_id = {}
|
||||
for r in routes:
|
||||
self.route_name_by_route_id[r['route_id']] = r['route_short_name']
|
||||
|
||||
# Schedule refresh
|
||||
if update_interval_seconds:
|
||||
self._refresh_thread = threading.Thread(target=lambda: every(self._update_interval_seconds, self.refresh))
|
||||
|
||||
|
||||
def loadfrom(self, zipfile: zipfile.ZipFile, name: str, filter: callable = None) -> map:
|
||||
"""
|
||||
Load a CSV file from the zip
|
||||
"""
|
||||
with zipfile.open(name, "r") as datafile:
|
||||
if not datafile:
|
||||
raise Exception('File %s is not in the zipped data'.format(name))
|
||||
if filter:
|
||||
result = []
|
||||
for r in csv.DictReader(TextIOWrapper(datafile, "utf-8-sig")):
|
||||
if filter(r):
|
||||
result.append(r)
|
||||
else:
|
||||
result = [r for r in csv.DictReader(TextIOWrapper(datafile, "utf-8-sig"))]
|
||||
return result
|
||||
|
||||
|
||||
def update_schedule_from(self, gtfsr_json: str) -> list:
|
||||
"""
|
||||
Creates a structure with the routes and arrival times from the
|
||||
preloaded information, plus the gtfsr data received from the API
|
||||
"""
|
||||
# Parse JSON
|
||||
gtfsr_data = json.loads(gtfsr_json)
|
||||
entities = gtfsr_data['Entity']
|
||||
|
||||
arrivals = []
|
||||
|
||||
for e in entities:
|
||||
# Skip non-updates and invalid entries
|
||||
if (e.get('IsDeleted')
|
||||
or not e.get('TripUpdate')
|
||||
or not e['TripUpdate'].get('Trip')
|
||||
or not e['TripUpdate']['Trip'].get('TripId') in self.selected_trip_ids):
|
||||
continue
|
||||
|
||||
# e contains an update for a trip we are interested in.
|
||||
stop_times = e['TripUpdate'].get('StopTimeUpdate')
|
||||
if not stop_times:
|
||||
print('A TripUpdate entry does not have StopTimeUpdate:')
|
||||
print(e)
|
||||
continue
|
||||
|
||||
for st in stop_times:
|
||||
# Skip the stops we are not interested in
|
||||
if not st.get('StopId') in self.selected_stop_ids:
|
||||
continue
|
||||
|
||||
# We have a stop time for one of our stops. Collect all info
|
||||
trip_id = e['TripUpdate']['Trip']['TripId']
|
||||
trip = self.trip_by_trip_id[trip_id]
|
||||
trip_destination = trip['trip_headsign']
|
||||
if len(trip_destination.split(' - ')) > 1:
|
||||
trip_destination = trip_destination.split(' - ')[1]
|
||||
route_id = self.trip_by_trip_id[trip_id]['route_id']
|
||||
route_name = self.route_name_by_route_id[route_id]
|
||||
stop_name = self.stops_by_stop_id[st['StopId']]['stop_name']
|
||||
stop_time = self.calculate_delta(
|
||||
self.stop_time_by_stop_and_trip_id[(st['StopId'], trip_id)]['arrival_time'],
|
||||
st['Arrival'].get('Delay') or 0
|
||||
)
|
||||
current_timestamp = (mktime(datetime.now().timetuple()))
|
||||
due_in_seconds = stop_time - current_timestamp
|
||||
|
||||
arrival_time = ArrivalTime(stop_name, route_name, trip_destination, due_in_seconds)
|
||||
arrivals.append(arrival_time)
|
||||
arrivals = sorted(arrivals)
|
||||
return arrivals
|
||||
|
||||
def refresh(self) -> None:
|
||||
""" Poll for new and updated information. Queue it for display update. """
|
||||
# Retrieve the updated json
|
||||
url_opener = urllib.request.URLopener()
|
||||
url_opener.addheader('x-api-key', API_KEY)
|
||||
response = url_opener.open(GTFS_R_URL)
|
||||
gtfs_r_json = response.file.read()
|
||||
arrivals = gtfs.update_schedule_from(gtfs_r_json)
|
||||
self._update_queue.put(arrivals)
|
||||
|
||||
def calculate_delta(self, stop_time: str, delta: int) -> datetime:
|
||||
"""
|
||||
Returns a unix timestamp of
|
||||
"""
|
||||
stop_time_parts = list(map(lambda n: int(n), stop_time.split(':')))
|
||||
initial = datetime.combine(datetime.now(),
|
||||
time(stop_time_parts[0], stop_time_parts[1], stop_time_parts[2]))
|
||||
adjusted = initial + timedelta(seconds = delta)
|
||||
return int(mktime(adjusted.timetuple()))
|
||||
|
||||
|
||||
def every(delay, task) -> None:
|
||||
""" Auxilliary function to schedule updates.
|
||||
Taken from https://stackoverflow.com/questions/474528/what-is-the-best-way-to-repeatedly-execute-a-function-every-x-seconds
|
||||
"""
|
||||
next_time = time() + delay
|
||||
while True:
|
||||
time.sleep(max(0, next_time - time()))
|
||||
try:
|
||||
task()
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
# in production code you might want to have this instead of course:
|
||||
# logger.exception("Problem while executing repetitive task.")
|
||||
# skip tasks if we are behind schedule:
|
||||
next_time += (time.time() - next_time) // delay * delay + delay
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
gtfs = GTFSClient([
|
||||
"Priory Walk, stop 1114",
|
||||
"College Drive, stop 2410",
|
||||
"Kimmage Road Lower, stop 2438",
|
||||
"Brookfield, stop 2437"
|
||||
], queue.Queue(), None)
|
||||
|
||||
if True:
|
||||
o = urllib.request.URLopener()
|
||||
o.addheader('x-api-key', API_KEY)
|
||||
r = o.open(GTFS_R_URL)
|
||||
if r.code != 200:
|
||||
print(r.file.read())
|
||||
exit(1)
|
||||
gtfs_r_json = r.file.read()
|
||||
else:
|
||||
gtfs_r_json = open('example.json').read()
|
||||
|
||||
arrivals = gtfs.update_schedule_from(gtfs_r_json)
|
||||
print(gtfs)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue