Spaces:
Runtime error
Runtime error
from inference.core.env import AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, LAMBDA | |
from inference.core.models.classification_base import ( | |
ClassificationBaseOnnxRoboflowInferenceModel, | |
) | |
class VitClassification(ClassificationBaseOnnxRoboflowInferenceModel):
    """VitClassification handles classification inference for Vision
    Transformer (ViT) models using ONNX.

    Inherits:
        ClassificationBaseOnnxRoboflowInferenceModel: Base class for ONNX
            Roboflow classification inference. Any mixin behavior (e.g. a
            ClassificationMixin) can only arrive through this base; no other
            base class is declared here.

    Attributes:
        multiclass (bool): Whether the model should handle multiclass
            classification. Read from the ``MULTICLASS`` key of
            ``self.environment``; ``False`` when that key is absent.
    """

    def __init__(self, *args, **kwargs):
        """Initialize the VitClassification instance.

        Args:
            *args: Positional arguments forwarded unchanged to the base class.
            **kwargs: Keyword arguments forwarded unchanged to the base class.
        """
        super().__init__(*args, **kwargs)
        # NOTE(review): ``self.environment`` is presumably populated by the base
        # class constructor above. The stored value is used as-is (no bool
        # coercion), so a string such as "False" would be truthy — confirm the
        # value types the environment actually holds.
        self.multiclass = self.environment.get("MULTICLASS", False)

    def weights_file(self) -> str:
        """Return the ONNX weights file name for this model.

        ``"weights.onnx"`` is returned only when ``AWS_ACCESS_KEY_ID``,
        ``AWS_SECRET_ACCESS_KEY`` and ``LAMBDA`` are all truthy. In every other
        case, including AWS credentials being set while ``LAMBDA`` is falsy,
        ``"best.onnx"`` is returned.

        Returns:
            str: The weights file name.
        """
        if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY and LAMBDA:
            return "weights.onnx"
        return "best.onnx"