Joshua Sundance Bailey commited on
Commit
86262ce
·
1 Parent(s): cab77bb

max tokens slider

Browse files
langchain-streamlit-demo/app.py CHANGED
@@ -94,7 +94,21 @@ if provider_api_key:
94
  help="Higher values give more random results.",
95
  )
96
 
97
- chain = get_llm_chain(model, provider_api_key, system_prompt, temperature)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
  run_collector = RunCollectorCallbackHandler()
100
 
 
94
  help="Higher values give more random results.",
95
  )
96
 
97
+ max_tokens = st.sidebar.slider(
98
+ "Max Tokens",
99
+ min_value=0,
100
+ max_value=8000,
101
+ value=1000,
102
+ help="Higher values give longer results.",
103
+ )
104
+
105
+ chain = get_llm_chain(
106
+ model,
107
+ provider_api_key,
108
+ system_prompt,
109
+ temperature,
110
+ max_tokens,
111
+ )
112
 
113
  run_collector = RunCollectorCallbackHandler()
114
 
langchain-streamlit-demo/llm_stuff.py CHANGED
@@ -22,7 +22,8 @@ def get_memory() -> ConversationBufferMemory:
22
  def get_llm(
23
  model: str,
24
  provider_api_key: str,
25
- temperature,
 
26
  ):
27
  if model.startswith("gpt"):
28
  return ChatOpenAI(
@@ -30,6 +31,7 @@ def get_llm(
30
  openai_api_key=provider_api_key,
31
  temperature=temperature,
32
  streaming=True,
 
33
  )
34
  elif model.startswith("claude"):
35
  return ChatAnthropic(
@@ -37,6 +39,7 @@ def get_llm(
37
  anthropic_api_key=provider_api_key,
38
  temperature=temperature,
39
  streaming=True,
 
40
  )
41
  elif model.startswith("meta-llama"):
42
  return ChatAnyscale(
@@ -44,6 +47,7 @@ def get_llm(
44
  anyscale_api_key=provider_api_key,
45
  temperature=temperature,
46
  streaming=True,
 
47
  )
48
  else:
49
  raise NotImplementedError(f"Unknown model {model}")
@@ -54,6 +58,7 @@ def get_llm_chain(
54
  provider_api_key: str,
55
  system_prompt: str = _DEFAULT_SYSTEM_PROMPT,
56
  temperature: float = 0.7,
 
57
  ) -> LLMChain:
58
  """Return a basic LLMChain with memory."""
59
  prompt = ChatPromptTemplate.from_messages(
@@ -67,7 +72,7 @@ def get_llm_chain(
67
  ],
68
  ).partial(time=lambda: str(datetime.now()))
69
  memory = get_memory()
70
- llm = get_llm(model, provider_api_key, temperature)
71
  return LLMChain(prompt=prompt, llm=llm, memory=memory or get_memory())
72
 
73
 
 
22
  def get_llm(
23
  model: str,
24
  provider_api_key: str,
25
+ temperature: float,
26
+ max_tokens: int = 1000,
27
  ):
28
  if model.startswith("gpt"):
29
  return ChatOpenAI(
 
31
  openai_api_key=provider_api_key,
32
  temperature=temperature,
33
  streaming=True,
34
+ max_tokens=max_tokens,
35
  )
36
  elif model.startswith("claude"):
37
  return ChatAnthropic(
 
39
  anthropic_api_key=provider_api_key,
40
  temperature=temperature,
41
  streaming=True,
42
+ max_tokens_to_sample=max_tokens,
43
  )
44
  elif model.startswith("meta-llama"):
45
  return ChatAnyscale(
 
47
  anyscale_api_key=provider_api_key,
48
  temperature=temperature,
49
  streaming=True,
50
+ max_tokens=max_tokens,
51
  )
52
  else:
53
  raise NotImplementedError(f"Unknown model {model}")
 
58
  provider_api_key: str,
59
  system_prompt: str = _DEFAULT_SYSTEM_PROMPT,
60
  temperature: float = 0.7,
61
+ max_tokens: int = 1000,
62
  ) -> LLMChain:
63
  """Return a basic LLMChain with memory."""
64
  prompt = ChatPromptTemplate.from_messages(
 
72
  ],
73
  ).partial(time=lambda: str(datetime.now()))
74
  memory = get_memory()
75
+ llm = get_llm(model, provider_api_key, temperature, max_tokens)
76
  return LLMChain(prompt=prompt, llm=llm, memory=memory or get_memory())
77
 
78