Hatman committed
Commit efc47bc · verified · 1 parent: bf608d6

Update main.py

Files changed (1): main.py (+31 -22)
main.py CHANGED
@@ -22,6 +22,7 @@ import asyncio
 from typing import Optional
 from dotenv import load_dotenv
 import boto3
+from groq import Groq
 
 app = FastAPI()
 
@@ -33,11 +34,13 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+groqClient = Groq(api_key=os.environ.get("GROQ_API_KEY"))
+
 load_dotenv()
 token = os.environ.get("HF_TOKEN")
 login(token)
 
-prompt_model = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+prompt_model = "llama-3.1-8b-instant"
 magic_prompt_model = "Gustavosta/MagicPrompt-Stable-Diffusion"
 options = {"use_cache": False, "wait_for_model": True}
 parameters = {"return_full_text":False, "max_new_tokens":300}
@@ -88,33 +91,37 @@ async def core():
 
 
 def getPrompt(prompt, modelID, attempts=1):
-    input = prompt
     if modelID != magic_prompt_model:
-        tokenizer = AutoTokenizer.from_pretrained(modelID)
         chat = [
             {"role": "user", "content": prompt_base},
             {"role": "assistant", "content": prompt_assistant},
             {"role": "user", "content": prompt},
         ]
-        input = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
-    try:
-        print(modelID)
-        apiData={"inputs":input, "parameters": parameters, "options": options, "timeout": 45}
-        response = requests.post(API_URL + modelID, headers=headers, data=json.dumps(apiData))
-        if response.status_code == 200:
-            try:
-                responseData = response.json()
-                return responseData
-            except ValueError as e:
-                print(f"Error parsing JSON: {e}")
-        else:
-            print(f"Error from API: {response.status_code} - {response.text}")
+        try:
+            response = groqClient.chat.completions.create(messages=chat, temperature=1, max_tokens=2048, top_p=1, stream=False, stop=None, model=modelID)
+        except Exception as e:
+            print(f"An error occurred: {e}")
+            if attempts < 3:
+                return getPrompt(prompt, modelID, attempts + 1)
+    else:
+        try:
+            print(modelID)
+            apiData={"inputs":prompt, "parameters": parameters, "options": options, "timeout": 45}
+            response = requests.post(API_URL + modelID, headers=headers, data=json.dumps(apiData))
+            if response.status_code == 200:
+                try:
+                    responseData = response.json()
+                    return responseData
+                except ValueError as e:
+                    print(f"Error parsing JSON: {e}")
+            else:
+                print(f"Error from API: {response.status_code} - {response.text}")
+                if attempts < 3:
+                    return getPrompt(prompt, modelID, attempts + 1)
+        except Exception as e:
+            print(f"An error occurred: {e}")
             if attempts < 3:
                 getPrompt(prompt, modelID, attempts + 1)
-    except Exception as e:
-        print(f"An error occurred: {e}")
-        if attempts < 3:
-            getPrompt(prompt, modelID, attempts + 1)
     return response.json()
 
 @app.post("/inferencePrompt")
@@ -229,7 +236,7 @@ def lambda_image(prompt, modelID):
     return response_data['body']
 
 def inferenceAPI(model, item, attempts = 1):
-    print(model)
+    print(f'Inference model {model}')
     if attempts > 5:
         return 'An error occurred while processing', model
     prompt = item.prompt
@@ -285,7 +292,8 @@ def get_random_model(models):
     print("Choosing randomly")
     model = random.choice(models)
     last_two_models.append(model)
-    last_two_models = last_two_models[-5:]
+    last_two_models = last_two_models[-5:]
+
    return model
 
 def nsfw_check(item, attempts=1):
@@ -324,6 +332,7 @@ async def inference(item: Item):
     print(activeModels['text-to-image'])
     base64_img = ""
     model = item.modelID
+    print(f'Start Model {model}')
     NSFW = False
     try:
         if item.image:
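
For reference, the new Groq-backed path in getPrompt() reduces to a single chat-completions call: a primer user/assistant exchange followed by the real request. Below is a minimal, self-contained sketch of that flow under stated assumptions: GROQ_API_KEY is set in the environment, and generate_prompt plus the PRIMER_* strings are illustrative stand-ins (for prompt_base and prompt_assistant in main.py), not part of the commit.

import os

from groq import Groq

# Assumption: GROQ_API_KEY is already exported in the environment.
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

def generate_prompt(user_prompt, model="llama-3.1-8b-instant", attempts=1):
    # Same three-message shape as getPrompt(): primer exchange, then the request.
    chat = [
        {"role": "user", "content": "PRIMER_INSTRUCTIONS"},  # stand-in for prompt_base
        {"role": "assistant", "content": "PRIMER_REPLY"},    # stand-in for prompt_assistant
        {"role": "user", "content": user_prompt},
    ]
    try:
        response = client.chat.completions.create(
            messages=chat,
            model=model,
            temperature=1,
            max_tokens=2048,
            top_p=1,
            stream=False,
            stop=None,
        )
        return response.choices[0].message.content
    except Exception as e:
        # Retry up to three attempts total, propagating the retry's result.
        print(f"An error occurred: {e}")
        if attempts < 3:
            return generate_prompt(user_prompt, model, attempts + 1)
        return None

With stream=False the SDK returns a complete ChatCompletion object, so the generated text is read from choices[0].message.content rather than by parsing a raw HTTP response as the Hugging Face Inference API branch does.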