saving-willy-dev / src /input /input_observation.py
rmm
feat: extended InputObservation to contain species/prediction info
c915f7c
raw
history blame
10.3 kB
import hashlib
from input.input_validator import generate_random_md5
from numpy import ndarray
from streamlit.runtime.uploaded_file_manager import UploadedFile
import datetime
# autogenerated class to hold the input data
class InputObservation:
    """
    A class to hold an input observation and associated metadata.

    Attributes:
        image (ndarray):
            The image associated with the observation.
        latitude (float):
            The latitude where the observation was made.
        longitude (float):
            The longitude where the observation was made.
        author_email (str):
            The email of the author of the observation.
        image_datetime_raw (str):
            The datetime extracted from the observation file.
        date (datetime.date):
            Date of the observation.
        time (datetime.time):
            Time of the observation.
        uploaded_file (UploadedFile):
            The uploaded file associated with the observation.
        image_md5 (str):
            The MD5 hash of the image associated with the observation.
            Required: the constructor raises ValueError when it is None.
        top_predictions (list):
            Read-only. Predictions stored via set_top_predictions().
        selected_class:
            Read-only. The class currently selected for the observation.
        class_overriden (bool):
            Read-only. True once a class other than the first prediction
            has been selected.

    Methods:
        set_top_predictions(top_predictions):
            Stores the predictions and selects the first one.
        set_selected_class(selected_class):
            Selects a class, flagging an override if it is not the first
            prediction.
        set_class_overriden(class_overriden):
            Sets the override flag directly.
        __str__():
            Returns a string representation of the observation.
        __repr__():
            Returns a string representation of the observation.
        __eq__(other):
            Checks if two observations are equal.
        __ne__(other):
            Checks if two observations are not equal.
        show_diff(other):
            Shows the differences between two observations.
        to_dict():
            Converts the observation to a dictionary.
        from_dict(data):
            Creates an observation from a dictionary.
        from_input(input):
            Creates an observation from another input observation.
    """

    # Number of constructor calls so far; each instance takes the next value
    # as its `_inst_id`.
    _inst_count = 0

    def __init__(
        self, image: ndarray = None, latitude: float = None,
        longitude: float = None, author_email: str = None,
        image_datetime_raw: str = None,
        date: datetime.date = None,
        time: datetime.time = None,
        uploaded_file: "UploadedFile" = None, image_md5: str = None,
    ):
        """
        Store the observation fields and assign an instance id.

        Raises:
            ValueError: if `image_md5` is None.
        """
        self.image = image
        self.latitude = latitude
        self.longitude = longitude
        self.author_email = author_email
        self.image_datetime_raw = image_datetime_raw
        self.date = date
        self.time = time
        self.uploaded_file = uploaded_file
        self.image_md5 = image_md5
        # attributes that get set after predictions/processing
        self._top_predictions = []
        self._selected_class = None
        self._class_overriden = False

        # The counter is bumped before validation, so a rejected construction
        # still consumes an id.
        InputObservation._inst_count += 1
        self._inst_id = InputObservation._inst_count

        # dbg - temporarily give up if hash is not provided
        if self.image_md5 is None:
            raise ValueError(f"Image MD5 hash is required - {self._inst_id:3}.")

    def set_top_predictions(self, top_predictions: list):
        """Store the predictions and, if there are any, select the first."""
        self._top_predictions = top_predictions
        if len(top_predictions) > 0:
            self.set_selected_class(top_predictions[0])

    def set_selected_class(self, selected_class: str):
        """
        Select a class for this observation.

        If predictions are present and `selected_class` differs from the first
        one, the override flag is set. With no predictions stored there is
        nothing to compare against, so the flag is left untouched (previously
        this case raised IndexError).
        """
        self._selected_class = selected_class
        # NOTE(review): the flag is only ever raised here, never cleared when
        # the first prediction is re-selected -- confirm that is intended.
        if (len(self._top_predictions) > 0
                and selected_class != self._top_predictions[0]):
            self.set_class_overriden(True)

    def set_class_overriden(self, class_overriden: bool):
        """Set the override flag directly."""
        self._class_overriden = class_overriden

    # getters for the top_predictions, selected_class and class_overriden
    @property
    def top_predictions(self):
        """The predictions stored by set_top_predictions()."""
        return self._top_predictions

    @property
    def selected_class(self):
        """The currently selected class (None until one is selected)."""
        return self._selected_class

    @property
    def class_overriden(self):
        """Whether a class other than the first prediction was selected."""
        return self._class_overriden

    def assign_image_md5(self):
        """
        Deprecated: always raises.

        The hash is a required constructor argument, so there is nothing to
        assign here. The unreachable hashing code that used to follow the
        raise has been removed.

        Raises:
            DeprecationWarning: unconditionally.
        """
        raise DeprecationWarning("This method is deprecated. hash is a required constructor argument.")

    @staticmethod
    def _images_equal(first, second) -> bool:
        """
        Compare two images where either may be None.

        Two missing images are equal; a missing and a present image are not.
        Present images must have the same shape and identical elements. The
        shape is checked first because an elementwise `==` on mismatched
        shapes does not give a usable boolean answer.
        """
        if first is None or second is None:
            return first is None and second is None
        if first.shape != second.shape:
            return False
        return bool((first == second).all())

    def __str__(self):
        """Return a compact, unlabelled summary of the observation."""
        _im_str = "None" if self.image is None else f"image dims: {self.image.shape}"
        return (
            f"Observation: {_im_str}, {self.latitude}, {self.longitude}, "
            f"{self.author_email}, {self.image_datetime_raw}, {self.date}, "
            f"{self.time}, {self.uploaded_file}, {self.image_md5}"
        )

    def __repr__(self):
        """Return a labelled summary of the observation."""
        _im_str = "None" if self.image is None else f"image dims: {self.image.shape}"
        return (
            f"Observation: "
            f"Image: {_im_str}, "
            f"Latitude: {self.latitude}, "
            f"Longitude: {self.longitude}, "
            f"Author Email: {self.author_email}, "
            f"raw timestamp: {self.image_datetime_raw}, "
            f"Date: {self.date}, "
            f"Time: {self.time}, "
            # the separator after the filename was missing, which glued this
            # field to the next one
            f"Uploaded Filename: {self.uploaded_file}, "
            f"Image MD5 hash: {self.image_md5}"
        )

    def __eq__(self, other):
        """
        Check whether two observations carry the same data.

        Compared: image, latitude, longitude, author_email,
        image_datetime_raw, date, uploaded_file and image_md5.
        Not compared: time (see below), the instance id, and the
        prediction-related attributes.

        Returns NotImplemented when `other` is not an InputObservation.
        """
        # TODO: ensure this covers all the attributes (some have been added?)
        #  - except inst_id which is unique
        if not isinstance(other, InputObservation):
            return NotImplemented
        return (
            self._images_equal(self.image, other.image)
            and self.latitude == other.latitude
            and self.longitude == other.longitude
            and self.author_email == other.author_email
            and self.image_datetime_raw == other.image_datetime_raw
            and self.date == other.date
            # temporarily skip time, it is followed by the clock and that is
            # always different
            and self.uploaded_file == other.uploaded_file
            and self.image_md5 == other.image_md5
        )

    def show_diff(self, other):
        """
        Print the differences between two observations.

        Only fields that differ are listed, below a summary line stating how
        many differences were found; if there are none, a single line says so.
        """
        differences = []

        if self.image is None or other.image is None:
            # identity checks: `!=` between None and an array is elementwise
            if not (self.image is None and other.image is None):
                differences.append(f" Image is different. (types mismatch: {type(self.image)} vs {type(other.image)})")
        elif self.image.shape != other.image.shape:
            differences.append(f" Image is different. (shape mismatch: {self.image.shape} vs {other.image.shape})")
        else:
            mismatch = self.image != other.image
            if mismatch.any():
                cnt = mismatch.sum()
                differences.append(f" Image is different: {cnt} different pixels.")

        # Fields reported with both values. The raw timestamp used to be
        # labelled "Date", making it indistinguishable from the real date line.
        labelled_fields = (
            ("Latitude", self.latitude, other.latitude),
            ("Longitude", self.longitude, other.longitude),
            ("Author email", self.author_email, other.author_email),
            ("Raw timestamp", self.image_datetime_raw, other.image_datetime_raw),
            ("Date", self.date, other.date),
            ("Time", self.time, other.time),
        )
        for label, mine, theirs in labelled_fields:
            if mine != theirs:
                differences.append(f" {label} is different. (self: {mine}, other: {theirs})")

        if self.uploaded_file != other.uploaded_file:
            differences.append(" Uploaded filename is different.")
        if self.image_md5 != other.image_md5:
            differences.append(" Image MD5 hash is different.")

        if differences:
            print(f"Observations have {len(differences)} differences:")
            for diff in differences:
                print(diff)
        else:
            print("Observations are the same.")

    def __ne__(self, other):
        """Negation of __eq__, passing NotImplemented through unchanged."""
        result = self.__eq__(other)
        if result is NotImplemented:
            return NotImplemented
        return not result

    def to_dict(self):
        """
        Convert the observation to a dictionary.

        The image array and the uploaded file object are not included; only
        the uploaded file's name is. `date` and `time` are stored via str().
        """
        return {
            "image_filename": self.uploaded_file.name if self.uploaded_file else None,
            "image_md5": self.image_md5,
            "latitude": self.latitude,
            "longitude": self.longitude,
            "author_email": self.author_email,
            "image_datetime_raw": self.image_datetime_raw,
            "date": str(self.date),
            "time": str(self.time),
            "selected_class": self._selected_class,
            "top_prediction": self._top_predictions[0] if len(self._top_predictions) else None,
            "class_overriden": self._class_overriden,
            # uploaded_file itself can't be serialized in json, and is not
            # sent to the dataset anyway.
        }

    @classmethod
    def from_dict(cls, data):
        """
        Create an observation from a dictionary.

        Missing keys default to None. The hash is read from the "image_md5"
        key and passed as the constructor's `image_md5` argument (it was
        previously passed under a non-existent keyword, which always raised
        TypeError).

        Raises:
            ValueError: if `data` has no "image_md5" value.
        """
        return cls(
            image=data.get("image"),
            latitude=data.get("latitude"),
            longitude=data.get("longitude"),
            author_email=data.get("author_email"),
            image_datetime_raw=data.get("image_datetime_raw"),
            date=data.get("date"),
            time=data.get("time"),
            uploaded_file=data.get("uploaded_file"),
            image_md5=data.get("image_md5"),
        )

    @classmethod
    def from_input(cls, input):
        """
        Create an observation from another input observation.

        Copies the constructor fields only; prediction-related state is not
        carried over and the new instance gets its own instance id.
        """
        return cls(
            image=input.image,
            latitude=input.latitude,
            longitude=input.longitude,
            author_email=input.author_email,
            image_datetime_raw=input.image_datetime_raw,
            date=input.date,
            time=input.time,
            uploaded_file=input.uploaded_file,
            image_md5=input.image_md5,
        )