tangibleAI / models /base.py
Rahul Dubey
new file: actionFunctions.py
463297f
raw
history blame contribute delete
743 Bytes
from .realistic_vision_v6b1 import RealV6B1
from .sdxl import SDXL
class Model:
    """Wrapper that selects one backend by name, loads it, and runs predictions.

    Recognised names are ``'SDXL'`` and ``'REALV6B1'`` (exact, case-sensitive).
    Any other name does not raise: it leaves the wrapper unloaded
    (``modelObj`` and ``model`` are ``None``), which ``getModelState`` reports
    and ``predict`` refuses to use.
    """

    def __init__(self, modelName: str) -> None:
        """Select the backend called *modelName* and load it.

        Args:
            modelName: ``'SDXL'`` or ``'REALV6B1'``. Other values produce an
                unloaded wrapper instead of an error.
        """
        self.modelName = modelName

        # Pick the backend class first so the instantiate-and-load sequence
        # below is written once for every supported backend.
        if modelName == 'SDXL':
            backend_cls = SDXL
        elif modelName == 'REALV6B1':
            backend_cls = RealV6B1
        else:
            backend_cls = None

        if backend_cls is None:
            self.modelObj = None
            self.model = None
        else:
            self.modelObj = backend_cls()
            # The handle returned by load_model() is kept separately because
            # the backend expects it back as the first argument of predict().
            self.model = self.modelObj.load_model()

    def getModelState(self) -> str:
        """Return ``"Model Loaded"`` if a backend was loaded, else ``"Model Not Loaded"``."""
        if self.modelObj is None:
            return "Model Not Loaded"
        return "Model Loaded"

    def predict(self, prompt):
        """Run the loaded backend on *prompt* and return whatever it returns.

        Args:
            prompt: Forwarded unchanged to the backend's ``predict``
                (presumably a text prompt — confirm against the backends).

        Raises:
            RuntimeError: if no backend was loaded because ``modelName`` was
                not recognised.
        """
        if self.modelObj is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise RuntimeError(
                f"Model {self.modelName!r} is not loaded; cannot run predict()"
            )
        return self.modelObj.predict(self.model, prompt)