SivaResearch commited on
Commit
4ca0de5
·
verified ·
1 Parent(s): c7eca26

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -30
app.py CHANGED
@@ -24,43 +24,62 @@ SYSTEM_PROMPT = """<s>[INST] <<SYS>>
24
 
25
  आपका प्रमुख लक्ष्य है यह है कि आप कृषि क्षेत्र में उपयुक्त ज्ञान प्रदान करें। आपके ज्ञान का धन्यवाद।
26
  <</SYS>>
27
-
28
  """
29
 
30
- # Formatting function for message and history
31
- def format_message(message: str, history: list, memory_limit: int = 3) -> str:
32
- if len(history) > memory_limit:
33
- history = history[-memory_limit:]
34
-
35
- if len(history) == 0:
36
- return SYSTEM_PROMPT + f"{message} [/INST]"
37
-
38
- formatted_message = SYSTEM_PROMPT + f"{history[0][0]} [/INST] {history[0][1]} </s>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
- for user_msg, model_answer in history[1:]:
41
- formatted_message += f"<s>[INST] {user_msg} [/INST] {model_answer} </s>"
42
 
43
- formatted_message += f"<s>[INST] {message} [/INST]"
 
 
 
 
44
 
45
- return formatted_message
46
 
47
- def inference(input_prompts, model, tokenizer):
48
- input_prompts = [
49
- tokenizer.encode(input_prompt, return_tensors="pt", max_length=1024, truncation=True)
50
- for input_prompt in input_prompts
51
- ]
52
 
53
- with torch.inference_mode():
54
- outputs = model.generate(input_prompts[0], do_sample=True, top_k=10, max_length=1024)
55
 
56
- output_texts = tokenizer.decode(outputs[0], skip_special_tokens=True)
57
- return output_texts
58
 
59
- def get_llama_response(message: str, history: list) -> str:
60
- query = format_message(message, history)
61
- response = inference([query], model, tokenizer)
62
-
63
- print("Chatbot:", response.strip())
64
- return response.strip()
65
 
66
- gr.ChatInterface(get_llama_response).launch()
 
24
 
25
  आपका प्रमुख लक्ष्य है यह है कि आप कृषि क्षेत्र में उपयुक्त ज्ञान प्रदान करें। आपके ज्ञान का धन्यवाद।
26
  <</SYS>>
 
27
  """
28
 
29
+ device = "cuda" if torch.cuda.is_available() else "cpu"
30
+
31
+
32
+ def create_prompt_with_chat_format(messages, bos="<s>", eos="</s>", add_bos=True, system_prompt="System: "):
33
+ formatted_text = ""
34
+ for message in messages:
35
+ if message["role"] == "system":
36
+ formatted_text += system_prompt + message["content"] + "\n"
37
+ elif message["role"] == "user":
38
+ formatted_text += "\n" + message["content"] + "\n"
39
+ elif message["role"] == "assistant":
40
+ formatted_text += "\n" + message["content"].strip() + eos + "\n"
41
+ else:
42
+ raise ValueError(
43
+ "Chat template only supports 'system', 'user', and 'assistant' roles. Invalid role: {}.".format(
44
+ message["role"]
45
+ )
46
+ )
47
+ formatted_text += "\n"
48
+ formatted_text = bos + formatted_text if add_bos else formatted_text
49
+ return formatted_text
50
+
51
+
52
+ def inference(input_prompts, model, tokenizer, system_prompt="System: "):
53
+ output_texts = []
54
+ for input_prompt in input_prompts:
55
+ formatted_query = create_prompt_with_chat_format([{"role": "user", "content": input_prompt}], add_bos=False, system_prompt=system_prompt)
56
+ encodings = tokenizer(formatted_query, padding=True, return_tensors="pt")
57
+ encodings = encodings.to(device)
58
+
59
+ with torch.no_grad():
60
+ outputs = model.generate(encodings.input_ids, do_sample=False, max_length=250)
61
+
62
+ output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
63
+ output_texts.append(output_text[len(input_prompt):])
64
+ return output_texts
65
 
 
 
66
 
67
+ examples = [
68
+ ["मुझे अपने करियर के बारे में सुझाव दो", "मैं कैसे अध्ययन कर सकता हूँ?"],
69
+ ["कृपया मुझे एक कहानी सुनाएं", "ताजमहल के बारे में कुछ बताएं"],
70
+ ["मेरा नाम क्या है?", "आपका पसंदीदा फिल्म कौन सी है?"],
71
+ ]
72
 
 
73
 
74
+ def get_llama_response(message: str, history: list, system_prompt=SYSTEM_PROMPT) -> str:
75
+ formatted_history = [{"role": "user", "content": hist} for hist in history]
76
+ formatted_message = {"role": "user", "content": message}
 
 
77
 
78
+ formatted_query = create_prompt_with_chat_format(formatted_history + [formatted_message], add_bos=False, system_prompt=system_prompt)
79
+ response = inference([formatted_query], model, tokenizer)
80
 
81
+ print("Chatbot:", response[0].strip())
82
+ return response[0].strip()
83
 
 
 
 
 
 
 
84
 
85
+ gr.ChatInterface(fn=get_llama_response, inputs=["text", "text", "text"], outputs="text").launch()