Spaces:
Runtime error
Runtime error
Commit
·
c9fffcd
1
Parent(s):
96b7a63
Update components/custom_llm.py
Browse files- components/custom_llm.py +16 -0
components/custom_llm.py
CHANGED
@@ -6,6 +6,22 @@ from typing import Literal
|
|
6 |
import requests
|
7 |
from langchain.prompts import PromptTemplate, ChatPromptTemplate
|
8 |
from operator import itemgetter
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
class CustomLLM(LLM):
|
11 |
repo_id : str
|
|
|
6 |
import requests
|
7 |
from langchain.prompts import PromptTemplate, ChatPromptTemplate
|
8 |
from operator import itemgetter
|
9 |
+
import re
|
10 |
+
|
11 |
+
|
12 |
+
def format_captions(text):
|
13 |
+
sentences = list(filter(None,[x.strip() for x in re.split(r'[^A-Za-z0-9 -]', text)]))
|
14 |
+
print(len(sentences))
|
15 |
+
return "\n".join([f"{i+1}. {d}" for i,d in enumerate(sentences)])
|
16 |
+
|
17 |
+
def custom_chain():
|
18 |
+
|
19 |
+
prompt = PromptTemplate.from_template("<s><INST>Given the below template, create a list of image generation prompt with maximum 5 words for each number\n\n{template}<INST> ")
|
20 |
+
|
21 |
+
cap_llm = CustomLLM(repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1", model_type='text-generation', api_token=API_TOKEN, stop=["\n<|"])
|
22 |
+
|
23 |
+
return {"template":lambda x:format_captions(x)} | prompt | cap_llm
|
24 |
+
|
25 |
|
26 |
class CustomLLM(LLM):
|
27 |
repo_id : str
|