paloma99 committed on
Commit f5b6824 · verified · 1 Parent(s): df7209b

Update app.py

Files changed (1)
  1. app.py +4 -101
app.py CHANGED
@@ -7,40 +7,13 @@ theme = theme.Theme()
 
 
 
-import os
-import sys
-sys.path.append('../..')
-
-#langchain
-from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter
-from langchain.embeddings import HuggingFaceEmbeddings
-from langchain.prompts import PromptTemplate
-from langchain.chains import RetrievalQA
-from langchain.prompts import ChatPromptTemplate
-from langchain.schema import StrOutputParser
-from langchain.schema.runnable import Runnable
-from langchain.schema.runnable.config import RunnableConfig
-from langchain.chains import (
-    LLMChain, ConversationalRetrievalChain)
-from langchain.vectorstores import Chroma
-from langchain.memory import ConversationBufferMemory
-from langchain.chains import LLMChain
-from langchain.prompts.prompt import PromptTemplate
-from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate
-from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate, MessagesPlaceholder
-from langchain.document_loaders import PyPDFDirectoryLoader
-
-from langchain_community.llms import HuggingFaceHub
-
-from pydantic import BaseModel
-import shutil
 
 
 
 
 # Cell 1: Image Classification Model
-image_pipeline = pipeline(task="image-classification", model="guillen/vit-basura-test1")
+image_pipeline = pipeline(task="image-classification", model="rocioadlc/TrashNet_ResNet152V2")
 
 def predict_image(input_img):
     predictions = image_pipeline(input_img)
@@ -56,81 +29,11 @@ image_gradio_app = gr.Interface(
 
 # Cell 2: Chatbot Model
 
-loader = PyPDFDirectoryLoader('pdfs')
-data = loader.load()
-# split documents
-text_splitter = RecursiveCharacterTextSplitter(
-    chunk_size=500,
-    chunk_overlap=70,
-    length_function=len
-)
-docs = text_splitter.split_documents(data)
-# define embedding
-embeddings = HuggingFaceEmbeddings(model_name='thenlper/gte-small')
-# create vector database from data
-persist_directory = 'docs/chroma/'
-
-# Remove old database files if any
-shutil.rmtree(persist_directory, ignore_errors=True)
-vectordb = Chroma.from_documents(
-    documents=docs,
-    embedding=embeddings,
-    persist_directory=persist_directory
-)
-# define retriever
-retriever = vectordb.as_retriever(search_type="mmr")
-template = """
-Your name is AngryGreta and you are a recycling chatbot with the objective to answer questions from the user in English or Spanish /
-Use the following pieces of context to answer the question if the question is related to recycling /
-No more than two chunks of context /
-Answer in the same language as the question /
-Always say "thanks for asking!" at the end of the answer /
-If the context is not relevant, please answer the question by using your own knowledge about the topic.
-
-context: {context}
-question: {question}
-"""
-
-# Create the chat prompt templates
-system_prompt = SystemMessagePromptTemplate.from_template(template)
-qa_prompt = ChatPromptTemplate(
-    messages=[
-        system_prompt,
-        MessagesPlaceholder(variable_name="chat_history"),
-        HumanMessagePromptTemplate.from_template("{question}")
-    ]
-)
-llm = HuggingFaceHub(
-    repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
-    task="text-generation",
-    model_kwargs={
-        "max_new_tokens": 1024,
-        "top_k": 30,
-        "temperature": 0.1,
-        "repetition_penalty": 1.03,
-    },
-)
-
-memory = ConversationBufferMemory(llm=llm, memory_key="chat_history", input_key='question', output_key='answer', return_messages=True)
-
-qa_chain = ConversationalRetrievalChain.from_llm(
-    llm=llm,
-    memory=memory,
-    retriever=retriever,
-    verbose=True,
-    combine_docs_chain_kwargs={'prompt': qa_prompt},
-    get_chat_history=lambda h: h,
-    rephrase_question=False,
-    output_key='answer'
-)
-
-def chat_interface(question, history):
-    result = qa_chain.invoke({"question": question})
-    return result['answer']  # return only the answer text
+def echo(message, history):
+    return message
 
 chatbot_gradio_app = gr.ChatInterface(
-    fn=chat_interface,
+    fn=echo,
     title='Green Greta'
 )
 
 
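The hunks above only show fragments of app.py. Below is a minimal sketch of how the simplified file might fit together after this commit, reconstructed from the context and added lines (theme.Theme(), the image pipeline, predict_image, echo, and the two Gradio apps). The imports, the gr.Interface input/output spec, the predict_image return value, and the TabbedInterface/launch wiring are assumptions, since they are not visible in the diff.

# Sketch only: reconstructed from this commit's context and added lines.
# Anything not shown in the hunks (imports, Interface spec, TabbedInterface
# wiring) is an assumption, not the author's confirmed code.
import gradio as gr
from transformers import pipeline

import theme  # assumed: the hunk header shows `theme = theme.Theme()`

theme = theme.Theme()

# Cell 1: Image Classification Model
image_pipeline = pipeline(task="image-classification", model="rocioadlc/TrashNet_ResNet152V2")

def predict_image(input_img):
    # map the pipeline's [{"label": ..., "score": ...}] output to a dict for gr.Label
    predictions = image_pipeline(input_img)
    return {p["label"]: p["score"] for p in predictions}

image_gradio_app = gr.Interface(  # input/output spec assumed; only the call is visible in the diff
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=3),
    title='Green Greta',
)

# Cell 2: Chatbot Model
def echo(message, history):
    return message

chatbot_gradio_app = gr.ChatInterface(
    fn=echo,
    title='Green Greta'
)

# assumed wiring: tab the two apps together and launch
app = gr.TabbedInterface([image_gradio_app, chatbot_gradio_app], ["Image", "Chat"])
app.launch()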