Spaces:
Running
Running
Update main.py
Browse files
main.py
CHANGED
@@ -22,6 +22,7 @@ import asyncio
|
|
22 |
from typing import Optional
|
23 |
from dotenv import load_dotenv
|
24 |
import boto3
|
|
|
25 |
|
26 |
app = FastAPI()
|
27 |
|
@@ -33,11 +34,13 @@ app.add_middleware(
|
|
33 |
allow_headers=["*"],
|
34 |
)
|
35 |
|
|
|
|
|
36 |
load_dotenv()
|
37 |
token = os.environ.get("HF_TOKEN")
|
38 |
login(token)
|
39 |
|
40 |
-
prompt_model = "
|
41 |
magic_prompt_model = "Gustavosta/MagicPrompt-Stable-Diffusion"
|
42 |
options = {"use_cache": False, "wait_for_model": True}
|
43 |
parameters = {"return_full_text":False, "max_new_tokens":300}
|
@@ -88,33 +91,37 @@ async def core():
|
|
88 |
|
89 |
|
90 |
def getPrompt(prompt, modelID, attempts=1):
|
91 |
-
input = prompt
|
92 |
if modelID != magic_prompt_model:
|
93 |
-
tokenizer = AutoTokenizer.from_pretrained(modelID)
|
94 |
chat = [
|
95 |
{"role": "user", "content": prompt_base},
|
96 |
{"role": "assistant", "content": prompt_assistant},
|
97 |
{"role": "user", "content": prompt},
|
98 |
]
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
if attempts < 3:
|
113 |
getPrompt(prompt, modelID, attempts + 1)
|
114 |
-
except Exception as e:
|
115 |
-
print(f"An error occurred: {e}")
|
116 |
-
if attempts < 3:
|
117 |
-
getPrompt(prompt, modelID, attempts + 1)
|
118 |
return response.json()
|
119 |
|
120 |
@app.post("/inferencePrompt")
|
@@ -229,7 +236,7 @@ def lambda_image(prompt, modelID):
|
|
229 |
return response_data['body']
|
230 |
|
231 |
def inferenceAPI(model, item, attempts = 1):
|
232 |
-
print(model)
|
233 |
if attempts > 5:
|
234 |
return 'An error occured when Processing', model
|
235 |
prompt = item.prompt
|
@@ -285,7 +292,8 @@ def get_random_model(models):
|
|
285 |
print("Choosing randomly")
|
286 |
model = random.choice(models)
|
287 |
last_two_models.append(model)
|
288 |
-
last_two_models = last_two_models[-5:]
|
|
|
289 |
return model
|
290 |
|
291 |
def nsfw_check(item, attempts=1):
|
@@ -324,6 +332,7 @@ async def inference(item: Item):
|
|
324 |
print(activeModels['text-to-image'])
|
325 |
base64_img = ""
|
326 |
model = item.modelID
|
|
|
327 |
NSFW = False
|
328 |
try:
|
329 |
if item.image:
|
|
|
22 |
from typing import Optional
|
23 |
from dotenv import load_dotenv
|
24 |
import boto3
|
25 |
+
from groq import Groq
|
26 |
|
27 |
app = FastAPI()
|
28 |
|
|
|
34 |
allow_headers=["*"],
|
35 |
)
|
36 |
|
37 |
+
# Load variables from a local .env file BEFORE any credential lookup below.
# Previously the Groq client was built first, so a GROQ_API_KEY that exists
# only in .env was not yet visible when os.environ.get() ran.
load_dotenv()

# Groq API client, configured from the GROQ_API_KEY environment variable.
groqClient = Groq(api_key=os.environ.get("GROQ_API_KEY"))

# NOTE(review): login() is imported outside this view - presumably
# huggingface_hub.login; confirm. It receives the HF_TOKEN value (or None
# when the variable is unset).
token = os.environ.get("HF_TOKEN")
login(token)

# Model identifiers used for prompt generation. getPrompt() compares its
# modelID argument against magic_prompt_model to pick the request path.
prompt_model = "llama-3.1-8b-instant"
magic_prompt_model = "Gustavosta/MagicPrompt-Stable-Diffusion"

# "options" and "parameters" are embedded verbatim in the JSON payload that
# getPrompt() posts to API_URL + modelID.
options = {"use_cache": False, "wait_for_model": True}
parameters = {"return_full_text": False, "max_new_tokens": 300}
|
|
|
91 |
|
92 |
|
93 |
def getPrompt(prompt, modelID, attempts=1):
    """Ask a text model to produce a prompt from *prompt*, retrying on failure.

    Any ``modelID`` other than ``magic_prompt_model`` is sent as a chat
    completion through the module-level ``groqClient``; the conversation is
    seeded with the module-level ``prompt_base`` / ``prompt_assistant``
    messages (defined outside this block).  ``magic_prompt_model`` is sent
    as an HTTP POST to ``API_URL + modelID``.

    A request counts as failed when it raises, or (HTTP path only) when the
    status code is not 200.  Failed requests are retried until three
    attempts in total have been made.

    Args:
        prompt: User text to expand.
        modelID: Identifier of the model to query.
        attempts: 1-based number of the current attempt.  Callers normally
            leave the default; it is kept for backward compatibility with
            the previous recursive implementation.

    Returns:
        ``response.json()`` of the first successful request.  If the HTTP
        path still answers with a non-200 status on the last attempt, the
        JSON body of that last response is returned, as before.

    Raises:
        Exception: The error of the last attempt when every attempt raised.
            (The previous code failed here with ``UnboundLocalError``.)
        ValueError: When the final response body cannot be decoded as JSON.
    """
    max_attempts = 3
    while True:
        try:
            response = _request_prompt(prompt, modelID)
        except Exception as e:
            print(f"An error occurred: {e}")
            if attempts >= max_attempts:
                # Out of retries: surface the real cause to the caller.
                raise
        else:
            # Only the HTTP path exposes a status code to check.
            failed = modelID == magic_prompt_model and response.status_code != 200
            if failed:
                print(f"Error from API: {response.status_code} - {response.text}")
            if not failed or attempts >= max_attempts:
                # NOTE(review): on the Groq path this calls the SDK response
                # object's own .json(); confirm the caller expects that shape.
                try:
                    return response.json()
                except ValueError as e:
                    print(f"Error parsing JSON: {e}")
                    raise
        attempts += 1


def _request_prompt(prompt, modelID):
    """Perform exactly one prompt-generation request and return the raw response."""
    if modelID != magic_prompt_model:
        chat = [
            {"role": "user", "content": prompt_base},
            {"role": "assistant", "content": prompt_assistant},
            {"role": "user", "content": prompt},
        ]
        # Use the Groq client created at module level (groqClient).
        return groqClient.chat.completions.create(
            messages=chat,
            temperature=1,
            max_tokens=2048,
            top_p=1,
            stream=False,
            stop=None,
            model=modelID,
        )

    print(modelID)
    # Send the caller's prompt text; the bare name `input` would refer to
    # the builtin function and make json.dumps() raise TypeError.
    apiData = {"inputs": prompt, "parameters": parameters, "options": options, "timeout": 45}
    return requests.post(API_URL + modelID, headers=headers, data=json.dumps(apiData))
|
126 |
|
127 |
@app.post("/inferencePrompt")
|
|
|
236 |
return response_data['body']
|
237 |
|
238 |
def inferenceAPI(model, item, attempts = 1):
|
239 |
+
print(f'Inference model {model}')
|
240 |
if attempts > 5:
|
241 |
return 'An error occured when Processing', model
|
242 |
prompt = item.prompt
|
|
|
292 |
print("Choosing randomly")
|
293 |
model = random.choice(models)
|
294 |
last_two_models.append(model)
|
295 |
+
last_two_models = last_two_models[-5:]
|
296 |
+
|
297 |
return model
|
298 |
|
299 |
def nsfw_check(item, attempts=1):
|
|
|
332 |
print(activeModels['text-to-image'])
|
333 |
base64_img = ""
|
334 |
model = item.modelID
|
335 |
+
print(f'Start Model {model}')
|
336 |
NSFW = False
|
337 |
try:
|
338 |
if item.image:
|