Spaces:
Runtime error
Runtime error
Commit
·
c9fffcd
1
Parent(s):
96b7a63
Update components/custom_llm.py
Browse files- components/custom_llm.py +16 -0
components/custom_llm.py
CHANGED
|
@@ -6,6 +6,22 @@ from typing import Literal
|
|
| 6 |
import requests
|
| 7 |
from langchain.prompts import PromptTemplate, ChatPromptTemplate
|
| 8 |
from operator import itemgetter
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
class CustomLLM(LLM):
|
| 11 |
repo_id : str
|
|
|
|
| 6 |
import requests
|
| 7 |
from langchain.prompts import PromptTemplate, ChatPromptTemplate
|
| 8 |
from operator import itemgetter
|
| 9 |
+
import re
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def format_captions(text):
|
| 13 |
+
sentences = list(filter(None,[x.strip() for x in re.split(r'[^A-Za-z0-9 -]', text)]))
|
| 14 |
+
print(len(sentences))
|
| 15 |
+
return "\n".join([f"{i+1}. {d}" for i,d in enumerate(sentences)])
|
| 16 |
+
|
| 17 |
+
def custom_chain():
|
| 18 |
+
|
| 19 |
+
prompt = PromptTemplate.from_template("<s><INST>Given the below template, create a list of image generation prompt with maximum 5 words for each number\n\n{template}<INST> ")
|
| 20 |
+
|
| 21 |
+
cap_llm = CustomLLM(repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1", model_type='text-generation', api_token=API_TOKEN, stop=["\n<|"])
|
| 22 |
+
|
| 23 |
+
return {"template":lambda x:format_captions(x)} | prompt | cap_llm
|
| 24 |
+
|
| 25 |
|
| 26 |
class CustomLLM(LLM):
|
| 27 |
repo_id : str
|