Spaces:
Paused
Paused
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import gradio as gr | |
model_name = "ai4bharat/Airavata" | |
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") | |
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16) | |
SYSTEM_PROMPT = """<s>[INST] <<SYS>> | |
नमस्कार! आप अब कृषि विशेषज्ञता बॉट के साथ इंटरैक्ट कर रहे हैं—एक उन्नत AI जो कृषि क्षेत्र में विशेषज्ञता प्रदान करने के लिए डिज़ाइन किया गया है। | |
कृपया ध्यान दें कि यह बॉट केवल हिंदी में जवाब देगा। इसकी क्षमताएँ शामिल हैं: | |
1. आधुनिक फसल प्रबंधन तकनीकों में गहरा ज्ञान। | |
2. कृषि में कीट और रोग नियंत्रण के लिए प्रभावी रणनीतियाँ। | |
3. मृदा स्वास्थ्य का सुधारने और पुनर्निर्माण के लिए विशेषज्ञता। | |
4. सतत और प्रेसिजन खेती के अभ्यासों का ज्ञान। | |
5. सिंचाई और जल प्रबंधन के लिए सर्वोत्तम अभ्यासों के लिए सुझाव। | |
6. रणनीतिक फसल चक्रण और इंटरक्रॉपिंग विधियों पर मार्गदर्शन। | |
7. नवीनतम कृषि प्रौद्योगिकियों और नवाचारों की जानकारी। | |
8. विशेष फसलों, जलवायु, और क्षेत्रों के लिए विशेषज्ञ सलाह। | |
कृपया पेशेवर रूप से बराबरी बनाए रखें और सुनिश्चित करें कि आपके जवाब सही और मूल्यवान हैं। उपयोगकर्ताओं से आगे की स्पष्टीकरण के लिए पूछने के लिए प्रोत्साहित करें। | |
आपका प्रमुख लक्ष्य है यह है कि आप कृषि क्षेत्र में उपयुक्त ज्ञान प्रदान करें। आपके ज्ञान का धन्यवाद। | |
<</SYS>> | |
""" | |
# Formatting function for message and history | |
def format_message(message: str, history: list, memory_limit: int = 3) -> str: | |
if len(history) > memory_limit: | |
history = history[-memory_limit:] | |
if len(history) == 0: | |
return SYSTEM_PROMPT + f"{message} [/INST]" | |
formatted_message = SYSTEM_PROMPT + f"{history[0][0]} [/INST] {history[0][1]} </s>" | |
for user_msg, model_answer in history[1:]: | |
formatted_message += f"<s>[INST] {user_msg} [/INST] {model_answer} </s>" | |
formatted_message += f"<s>[INST] {message} [/INST]" | |
return formatted_message | |
def inference(input_prompts, model, tokenizer): | |
input_prompts = [ | |
tokenizer.encode(input_prompt, return_tensors="pt", max_length=1024, truncation=True) | |
for input_prompt in input_prompts | |
] | |
with torch.inference_mode(): | |
outputs = model.generate(input_prompts[0], do_sample=True, top_k=10, max_length=1024) | |
output_texts = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
return output_texts | |
def get_llama_response(message: str, history: list) -> str: | |
query = format_message(message, history) | |
response = inference([query], model, tokenizer) | |
print("Chatbot:", response.strip()) | |
return response.strip() | |
gr.ChatInterface(get_llama_response).launch() | |