|
|
|
from ragatouille import RAGPretrainedModel |
|
import subprocess |
|
import json |
|
import spaces |
|
import firebase_admin |
|
from firebase_admin import credentials, firestore |
|
import logging |
|
from pathlib import Path |
|
from time import perf_counter |
|
from datetime import datetime |
|
import gradio as gr |
|
from jinja2 import Environment, FileSystemLoader |
|
import numpy as np |
|
from sentence_transformers import CrossEncoder |
|
from huggingface_hub import InferenceClient |
|
from os import getenv |
|
|
|
from backend.query_llm import generate_hf, generate_openai |
|
from backend.semantic_search import table, retriever |
|
from huggingface_hub import InferenceClient |
|
|
|
|
|
# Column names read from rows of the retrieval table.
VECTOR_COLUMN_NAME = "vector"
TEXT_COLUMN_NAME = "text"

# Hugging Face token from the environment; None when the variable is unset.
HF_TOKEN = getenv("HUGGING_FACE_HUB_TOKEN")

# Directory containing this file; templates are resolved relative to it.
proj_dir = Path(__file__).parent

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hosted Mixtral client used by the quiz generator.
client = InferenceClient(model="mistralai/Mixtral-8x7B-Instruct-v0.1", token=HF_TOKEN)

# Jinja templates for the chat prompt and its HTML preview, loaded from <proj_dir>/templates.
env = Environment(loader=FileSystemLoader(searchpath=proj_dir.joinpath('templates')))
template = env.get_template('template.j2')
template_html = env.get_template('template_html.j2')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Sample questions shown under the chat box (gr.Examples); clicking one fills the textbox.
examples = [
    'My transhipment cargo is missing',
    'can u explain and tabulate difference between b 17 bond and a warehousing bond',
    'What are benefits of the AEO Scheme and eligibility criteria?',
    'What are penalties for customs offences? ',
    'what are penalties to customs officers misusing their powers under customs act?',
    'What are eligibility criteria for exemption from cost recovery charges',
    # "opening" was misspelled, which also degraded retrieval for this sample query.
    'list in detail what is procedure for obtaining new approval for opening a CFS attached to an ICD',
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def add_text(history, text):
    """Append the user's message to the chat history and lock the input textbox.

    Args:
        history: Current chat history (list of [user, bot] pairs) or None.
        text: The message the user submitted.

    Returns:
        A new history list ending with ``[text, None]``, plus a cleared, non-interactive Textbox.
    """
    history = [] if history is None else history
    # Use a mutable pair: `bot` fills in the reply slot in place via history[-1][1] = ...,
    # which would raise TypeError on a tuple if the value were not converted by Gradio first.
    history = history + [[text, None]]
    return history, gr.Textbox(value="", interactive=False)
|
|
|
|
|
# UI reranker labels -> cross-encoder checkpoints used for the non-ColBERT retrieval path.
_RERANKER_MODELS = {
    '(FAST) MiniLM-L6v2': 'cross-encoder/ms-marco-MiniLM-L-6-v2',
    '(ACCURATE) BGE reranker': 'BAAI/bge-reranker-base',
}
# Used when the label has no mapping (e.g. None when no radio option is selected).
_FALLBACK_RERANKER = 'BAAI/bge-reranker-base'
# Loaded CrossEncoder instances keyed by checkpoint name (at most one per entry above).
_reranker_cache = {}


def _get_reranker(choice):
    """Return a cached CrossEncoder for the UI label *choice*, loading it on first use."""
    model_name = _RERANKER_MODELS.get(choice, _FALLBACK_RERANKER)
    if model_name not in _reranker_cache:
        _reranker_cache[model_name] = CrossEncoder(model_name)
    return _reranker_cache[model_name]


def _get_colbert_index():
    """Return the ColBERT index, loading it once and caching it in the shared ``RAG_db`` state."""
    if RAG_db.value is None:
        # RAGPretrainedModel.from_index is a classmethod that loads the index with its ColBERT
        # checkpoint, so a separate from_pretrained() call would only load an unused extra copy.
        RAG_db.value = RAGPretrainedModel.from_index('.ragatouille/colbert/indexes/cbseclass10index')
    return RAG_db.value


def _retrieve_colbert(query, top_k):
    """Return the text of the *top_k* ColBERT search hits for *query*."""
    return [hit['content'] for hit in _get_colbert_index().search(query, k=top_k)]


def _retrieve_reranked(query, reranker_choice, top_rerank, top_k):
    """Vector-search *top_rerank* candidates, rerank them with a cross-encoder and keep *top_k*."""
    started = perf_counter()

    query_vec = retriever.encode(query)
    logger.warning('Finished query vec')

    hits = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME).limit(top_rerank).to_list()
    logger.warning('Finished search')
    documents = [hit[TEXT_COLUMN_NAME] for hit in hits]
    logger.warning(f'start cross encoder {len(documents)}')

    if documents:
        scores = _get_reranker(reranker_choice).predict([[query, doc] for doc in documents])
        logger.warning(f'Finished cross encoder {len(documents)}')
        # Highest score first.
        ranked = np.argsort(scores)[::-1][:top_k]
        documents = [documents[idx] for idx in ranked]

    logger.warning(f'num documents {len(documents)}')
    logger.warning(f'Finished Retrieving documents in {round(perf_counter() - started, 2)} seconds...')
    return documents


def bot(history, cross_encoder):
    """Stream an answer to the newest chat message using retrieval-augmented generation.

    Args:
        history: Chat history as [user_message, bot_message] pairs; the last pair's bot slot is
            overwritten in place with each value yielded by ``generate_hf``.
        cross_encoder: Retriever label selected in the UI. '(HIGH ACCURATE) ColBERT' searches the
            ColBERT index; any other label uses vector search plus cross-encoder reranking.

    Yields:
        ``(history, prompt_html)`` after every chunk produced by ``generate_hf``.

    Raises:
        ValueError: if there is no message or the latest message is empty.
    """
    top_rerank = 25  # candidates fetched from the vector table before reranking
    top_k_rank = 20  # documents rendered into the prompt

    query = history[-1][0] if history else None
    if not query:
        gr.Warning("Please submit a non-empty string as a prompt")
        raise ValueError("Empty string was submitted")

    logger.warning('Retrieving documents...')
    if cross_encoder == '(HIGH ACCURATE) ColBERT':
        gr.Warning('Retrieving using ColBERT.. First time query will take a minute for model to load..pls wait')
        documents = _retrieve_colbert(query, top_k_rank)
    else:
        documents = _retrieve_reranked(query, cross_encoder, top_rerank, top_k_rank)

    prompt = template.render(documents=documents, query=query)
    prompt_html = template_html.render(documents=documents, query=query)

    history[-1][1] = ""
    for partial_reply in generate_hf(prompt, history[:-1]):
        history[-1][1] = partial_reply
        yield history, prompt_html

    logger.info('Final history is %s', history)
|
|
|
|
|
def system_instructions(question_difficulty, topic, documents_str):
    """Build the Mixtral ``[INST]`` prompt requesting a 10-question multiple-choice quiz as JSON.

    Args:
        question_difficulty: Difficulty word inserted into the prompt (e.g. "easy", "hard").
        topic: Quiz topic requested by the user.
        documents_str: Retrieved passages the questions must be drawn from.

    Returns:
        The prompt string wrapped in ``<s> [INST] ... [/INST]``. The reply is expected to use the
        keys "Q#", "Q#:C1".."Q#:C4" and "A#" (whose value is a choice key such as "Q1:C2").
    """
    return (
        "<s> [INST] You are a great teacher and your task is to create 10 questions with 4 choices "
        f"with a {question_difficulty} difficulty about topic request \" {topic} \" only from the "
        f"below given documents, {documents_str} then create the answers. Index in JSON format, "
        "the questions as \"Q#\":\"\" to \"Q#\":\"\", the four choices as \"Q#:C1\":\"\" to "
        "\"Q#:C4\":\"\", and the answers as \"A#\":\"Q#:C#\" to \"A#\":\"Q#:C#\". "
        # The caller parses the reply with json.loads, so ask for bare JSON only.
        "Respond with only the JSON object and no other text. [/INST]"
    )
|
|
|
# Process-wide holder for the loaded ColBERT index; its .value stays None until something loads it.
RAG_db = gr.State(value=None)
|
|
|
def load_model():
    """Load the ColBERT index into the shared ``RAG_db`` state.

    Returns:
        'Ready to Go!!' on success, otherwise an "Error loading model: ..." status string;
        failures are reported to the caller rather than raised.
    """
    try:
        # RAGPretrainedModel.from_index is a classmethod that loads the index together with its
        # ColBERT checkpoint, so a separate from_pretrained() call is unnecessary.
        RAG_db.value = RAGPretrainedModel.from_index('.ragatouille/colbert/indexes/cbseclass10index')
    except Exception as e:  # best-effort: surface the failure as a UI status message
        logger.exception('Error loading ColBERT index')
        return f"Error loading model: {e}"
    return 'Ready to Go!!'
|
|
|
|
|
_QUIZ_NUM_QUESTIONS = 10
_QUIZ_MAX_ATTEMPTS = 3
_QUIZ_INDEX_PATH = '.ragatouille/colbert/indexes/cbseclass10index'


def _hidden_quiz_radios():
    """Return one hidden Radio update per question slot."""
    return [gr.Radio(visible=False) for _ in range(_QUIZ_NUM_QUESTIONS)]


def _parse_quiz_json(text):
    """Parse the JSON object in an LLM reply, ignoring any text outside the outermost braces."""
    start, end = text.find('{'), text.rfind('}')
    if start == -1 or end < start:
        raise ValueError('no JSON object found in model response')
    return json.loads(text[start:end + 1])


def _build_quiz_radios(data):
    """Return a visible Radio for each of Q1..Q10, or None if any question or its answer is missing."""
    radios = []
    for n in range(1, _QUIZ_NUM_QUESTIONS + 1):
        question = data.get(f"Q{n}")
        # "A#" holds a choice key (e.g. "Q1:C2"); resolve it to the choice text.
        answer = data.get(data.get(f"A{n}"))
        if not question or not answer:
            return None
        choices = [data.get(f"Q{n}:C{i}", "Choice not found") for i in range(1, 5)]
        radios.append(gr.Radio(choices=choices, label=question, visible=True, interactive=True))
    return radios


def generate_quiz(question_difficulty, topic):
    """Generate a 10-question quiz about *topic* from ColBERT-retrieved documents.

    Args:
        question_difficulty: Difficulty label from the UI; None falls back to "average".
        topic: Free-text topic entered by the user.

    Returns:
        A list of 11 outputs: a status message followed by one Radio update per question slot.
        On success the module-level ``quiz_data`` holds the parsed quiz, and radio *i*
        corresponds to question ``Q{i+1}``.
    """
    global quiz_data

    if not topic or not topic.strip():
        return ['Please enter a valid topic.'] + _hidden_quiz_radios()

    top_k_rank = 10

    # Load the index once and share it via RAG_db (the same cache load_model() fills).
    try:
        if RAG_db.value is None:
            RAG_db.value = RAGPretrainedModel.from_index(_QUIZ_INDEX_PATH)
            gr.Warning('Model loaded!')
    except Exception as e:
        return [f"Error loading model: {e}"] + _hidden_quiz_radios()

    documents_full = RAG_db.value.search(topic, k=top_k_rank)
    documents_str = '\n'.join(
        f"[DOCUMENT {i}]: {item['content']}" for i, item in enumerate(documents_full, start=1)
    )
    formatted_prompt = system_instructions(question_difficulty or "average", topic, documents_str)

    for attempt in range(_QUIZ_MAX_ATTEMPTS):
        try:
            response = client.text_generation(
                formatted_prompt,
                temperature=0.2,
                max_new_tokens=4000,
                top_p=0.95,
                repetition_penalty=1.0,
                do_sample=True,
                # A different seed per attempt so a retry does not reproduce the same bad reply.
                seed=42 + attempt,
                stream=False,
                details=False,
                return_full_text=False,
            )
            data = _parse_quiz_json(str(response))
            radios = _build_quiz_radios(data)
        except Exception:  # malformed/unexpected LLM output: log and retry
            logger.exception('Quiz generation attempt %d failed', attempt + 1)
            continue

        if radios is not None:
            quiz_data = data
            return ['Quiz Generated!'] + radios
        logger.warning('Quiz generation attempt %d returned incomplete questions', attempt + 1)

    return ['Sorry. Pls try with another topic!'] + _hidden_quiz_radios()
|
|
|
def compare_answers(*user_answers):
    """Score the user's selections against the most recently generated quiz.

    Args:
        *user_answers: The selected choice text (or None if unanswered) for each question radio,
            in order; position *i* is compared with the answer to question ``Q{i+1}``.

    Returns:
        A Markdown message with the score out of 10, or a prompt to generate a quiz first.

    NOTE(review): ``quiz_data`` is a process-wide global, so concurrent users overwrite each
    other's quiz; a per-session ``gr.State`` would isolate them.
    """
    try:
        data = quiz_data
    except NameError:  # no quiz has been generated in this process yet
        return "### Please generate a quiz first!"

    # "A#" maps to a choice key such as "Q1:C2"; resolve it to that choice's text.
    correct_answers = [data.get(data.get(f"A{n}")) for n in range(1, 11)]
    # Compare each selection with its own question's answer; never count missing answers.
    score = sum(
        1
        for chosen, correct in zip(user_answers, correct_answers)
        if correct is not None and chosen == correct
    )

    if score > 7:
        message = f"### Excellent! You got {score} out of 10!"
    elif score > 5:
        message = f"### Good! You got {score} out of 10!"
    else:
        message = f"### You got {score} out of 10! Don't worry, you can prepare well and try better next time!"
    return message
|
|
|
|
|
with gr.Blocks(theme='NoCrypt/miku') as CHATBOT:
    # Header: title and credits on the left, logo on the right.
    with gr.Row():
        with gr.Column(scale=10):
            gr.HTML(value="""<div style="color: #FF4500;"><h1>ADWITIYA-</h1> <h1><span style="color: #008000">Custom Manual Chatbot and Quizbot</span></h1>
</div>""", elem_id='heading')

            gr.HTML(value="""
<p style="font-family: sans-serif; font-size: 16px;">
Using GenAI for CBIC Capacity Building - A free chat bot developed by National Customs Targeting Center using Open source LLMs for CBIC Officers
</p>
""", elem_id='Sub-heading')

            # elem_id previously had a trailing space, which is not a valid HTML id.
            gr.HTML(value="""<p style="font-family: Arial, sans-serif; font-size: 14px;">Developed by NCTC,Mumbai . Suggestions may be sent to <a href="mailto:[email protected]" style="color: #00008B; font-style: italic;">[email protected]</a>.</p>""", elem_id='Sub-heading1')

        with gr.Column(scale=3):
            gr.Image(value='logo.png', height=200, width=200)

    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        avatar_images=('https://aui.atlassian.com/aui/8.8/docs/images/avatar-person.svg',
                       'https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg'),
        bubble_full_width=False,
        show_copy_button=True,
        show_share_button=True,
    )

    with gr.Row():
        txt = gr.Textbox(
            scale=3,
            show_label=False,
            placeholder="Enter text and press enter",
            container=False,
        )
        txt_btn = gr.Button(value="Submit text", scale=1)

    cross_encoder = gr.Radio(
        choices=['(FAST) MiniLM-L6v2', '(ACCURATE) BGE reranker', '(HIGH ACCURATE) ColBERT'],
        value='(ACCURATE) BGE reranker',
        label="Embeddings",
        info="Only the first query to ColBERT may take a little time",
    )

    prompt_html = gr.HTML()

    # The Submit button and pressing Enter run the same chain: add the message (locking the
    # textbox), stream the bot reply and retrieved-context HTML, then re-enable the textbox.
    for trigger in (txt_btn.click, txt.submit):
        txt_msg = trigger(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
            bot, [chatbot, cross_encoder], [chatbot, prompt_html])
        txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)

    gr.Examples(examples, txt)
|
|
|
|
|
|
|
|
|
with gr.Blocks(title="Quiz Maker", theme=gr.themes.Default(primary_hue="green", secondary_hue="green"), css="style.css") as QUIZBOT:
    with gr.Column(scale=4):
        # The warning emoji were mis-decoded UTF-8 ("β οΈ"); restored to U+26A0 U+FE0F.
        gr.HTML("""
<center>
<h1><span style="color: purple;">ADWITIYA</span> Customs Manual Quizbot</h1>
<h2>Generative AI-powered Capacity building for Training Officers</h2>
<i>⚠️ NACIN Faculties create quiz from any topic dynamically for classroom evaluation after their sessions! ⚠️</i>
</center>
""")

    with gr.Column(scale=2):
        gr.HTML("""
<center>
<h2>Ready!</h2>
</center>
""")

    topic = gr.Textbox(label="Enter the Topic for Quiz", placeholder="Write any topic/details from Customs Manual")

    with gr.Row():
        radio = gr.Radio(["easy", "average", "hard"], label="How difficult should the quiz be?")

    # The rocket emoji had been mis-decoded as "π" (first byte of its UTF-8 encoding).
    generate_quiz_btn = gr.Button("Generate Quiz!🚀")
    quiz_msg = gr.Textbox()

    # Ten hidden question slots; generate_quiz returns a status message plus one update per slot.
    question_radios = [gr.Radio(visible=False) for _ in range(10)]

    generate_quiz_btn.click(
        fn=generate_quiz,
        inputs=[radio, topic],
        outputs=[quiz_msg] + question_radios,
    )

    check_button = gr.Button("Check Score")
    score_textbox = gr.Markdown()

    check_button.click(
        fn=compare_answers,
        inputs=question_radios,
        outputs=score_textbox,
    )
|
|
|
# Both apps as tabs of one interface; `demo` stays module-level so Gradio tooling can find it.
demo = gr.TabbedInterface([CHATBOT, QUIZBOT], ["AI ChatBot", "AI Quizbot"])
demo.queue()

# Only start the server when run as a script, not when this module is imported.
if __name__ == "__main__":
    demo.launch(debug=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|