import gradio as gr

from langchain_community.llms import CTransformers
from langchain.prompts import FewShotChatMessagePromptTemplate, ChatPromptTemplate

from langchain_community.document_loaders import JSONLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from operator import itemgetter
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.outputs import Generation
from typing import Any, List

from pydantic import BaseModel, Field

# Model and sampling configuration
MODEL_PATH = "TheBloke/Mistral-7B-Claude-Chat-GGUF"
MODEL_FILE = "mistral-7b-claude-chat.Q4_K_M.gguf"
MODEL_TYPE = "mistral"
MAX_NEW_TOKENS = 100
temperature = 1
top_p = 0.95
top_k = 50
repetition_penalty = 1.5

# Load the quantized Mistral GGUF model locally via ctransformers (CPU only: gpu_layers=0)
llm = CTransformers(
    model=MODEL_PATH,
    model_file=MODEL_FILE,
    model_type=MODEL_TYPE,
    config={
        "max_new_tokens": MAX_NEW_TOKENS,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "last_n_tokens": 4,
        "stream": True,
        "gpu_layers": 0,
    },
)
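
# Quick smoke test for the local model (commented out so it doesn't run on import);
# the prompt text here is illustrative only.
# print(llm.invoke("Classify this name: John Smith"))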

# One worked example (name -> expected JSON answer) used for few-shot prompting
examples = [
    {
        "query": "Please classify this name: Ketan Jogadankar",
        "answer": """{
    "name": "Ketan Jogadankar",
    "label": "person",
    "score": 0.99,
    "reason": "Ketan is a common first name and Jogadankar looks like a surname."
}""",
    }
]

# Each example is rendered as a human/AI message pair
example_prompt = ChatPromptTemplate.from_messages(
    [("human", "{query}"),
     ("ai", "{answer}")]
)
prefix = """Act as an AI assistant that classifies names into 3 categories (person, business and other) based on the provided rules and example data. |
|
{format_instructions} |
|
Do not append any text to human input. |
|
Rules: |
|
* If the names contains the word "POD", classify it as a other. |
|
* If the names contains the word "trust", classify it as a other. |
|
* If the names contains the word "llc", classify it as a business. |
|
* If the name is non-profit organization then classify it as a other. |
|
Here are some examples: |
|
""" |
|
|
|
suffix = """Please classify this name: {name} |
|
""" |
|
|
|

# Wrap the worked examples in a chat-style few-shot template
few_shot_prompt_template = FewShotChatMessagePromptTemplate(
    examples=examples,
    example_prompt=example_prompt,
)

# Full prompt: system instructions, few-shot examples, then the human request
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", prefix),
        few_shot_prompt_template,
        ("human", suffix),
    ]
)
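
# Optional: inspect the assembled prompt without calling the model (commented out
# so it does not run on import; the placeholder values below are illustrative only).
# for m in prompt.format_messages(
#     name="Acme Holdings LLC",
#     context="<retrieved reference records>",
#     format_instructions="<schema text>",
# ):
#     print(m.type, ":", m.content)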

# Hand-written JSON schema instructions; this mirrors the NameClassification model defined below
format_instructions = """The output should be formatted as a JSON instance that conforms to the JSON schema below.
Here is the output schema:
```
{"properties": {"name": {"title": "Name", "description": "this is the input name passed by human", "type": "string"}, "label": {"title": "Label", "description": "this is the label predicted for input name", "type": "string"}, "score": {"title": "Score", "description": "This is confidence score for predicted label", "type": "number"}, "reason": {"title": "Reason", "description": "This is to explain why AI has predicted that label", "type": "string"}}, "required": ["name", "label", "score", "reason"]}
```
"""

# Load reference records from a local JSON file
data_loader = JSONLoader(file_path="document.json",
                         jq_schema=".", text_content=False)
data = data_loader.load()
data = [doc.page_content for doc in data]

# CharacterTextSplitter splits on blank lines; with chunk_size=2 every split becomes
# its own document (over-long splits only trigger a size warning)
splitter = CharacterTextSplitter(chunk_size=2, chunk_overlap=1)
documents = splitter.create_documents(texts=data)

# Embed the records and store them in a persistent Chroma vector store
docs_str = [doc.page_content for doc in documents]
sentence_emb = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

db = Chroma.from_texts(docs_str, sentence_emb, persist_directory="./temp_db")
db.persist()

# Retrieve the single most similar stored record for a given name
retriever = db.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 1},
)
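
# Optional sanity check (commented out so it doesn't run on import): assuming
# document.json exists and the store was built above, this should return the
# single most similar stored record for an illustrative query.
# print(retriever.invoke("Acme Holdings LLC"))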


class NameClassification(BaseModel):
    """Expected structured output for a classified name."""
    name: str = Field(description="this is the input name passed by human")
    label: str = Field(description="this is the label predicted for input name")
    score: float = Field(description="This is confidence score for predicted label")
    reason: str = Field(description="This is to explain why AI has predicted that label")

class CustomParser(JsonOutputParser):
    """JsonOutputParser that strips any junk the LLM emits around the JSON object."""

    def remove_junks(self, text: str) -> str:
        # Keep only the substring from the first "{" through the first "}"
        # (the expected output is a flat JSON object)
        start_index = text.index("{")
        stop_index = text.index("}") + 1
        return text[start_index:stop_index]

    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
        text = result[0].text
        text = text.strip()
        text = self.remove_junks(text)
        result = [Generation(text=text)]
        return super().parse_result(result=result, partial=partial)


parser = CustomParser(pydantic_object=NameClassification)
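
# Note: the hand-written format_instructions string above could, in principle, be
# generated from the same Pydantic model via the parser itself, e.g.:
# print(parser.get_format_instructions())
# remove_junks simply strips any chatter around the JSON object, for example:
# parser.remove_junks('Sure! {"name": "x", "label": "person", "score": 0.9, "reason": "..."}')
# -> '{"name": "x", "label": "person", "score": 0.9, "reason": "..."}'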

# LCEL chain: retrieve similar records, fill the prompt, run the LLM, parse the JSON
chain = (
    {"context": itemgetter("name") | retriever,
     "format_instructions": itemgetter("format_instructions"),
     "name": itemgetter("name")}
    | prompt
    | llm
    | parser
)
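
# Minimal end-to-end check (commented out so it doesn't run on import). Assuming the
# model file loads and document.json exists, the parsed result should be a dict with
# name/label/score/reason keys; "Acme Holdings LLC" is just an illustrative input.
# print(chain.invoke({"name": "Acme Holdings LLC",
#                     "format_instructions": format_instructions}))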


def predict(message, history):
    # Keep retrying until the chain returns a parseable result; the custom parser
    # raises when the model output is not valid JSON.
    result = None
    while result is None:
        try:
            result = chain.invoke({"name": message, "format_instructions": format_instructions})
        except Exception:
            pass
    yield str(result)


gr.ChatInterface(predict, title="Mistral 7B").queue().launch(debug=True)