Spaces:
Sleeping
Sleeping
fix: using hugging face
Browse files- services/model_handler.py +3 -20
services/model_handler.py
CHANGED
@@ -4,7 +4,7 @@ import streamlit as st
|
|
4 |
from agno.agent import Agent
|
5 |
from agno.tools.arxiv import ArxivTools
|
6 |
from agno.tools.pubmed import PubmedTools
|
7 |
-
from agno.models.
|
8 |
|
9 |
MODEL_PATH = "facebook/opt-125m"
|
10 |
|
@@ -22,24 +22,7 @@ class ModelHandler:
|
|
22 |
def _initialize_model(self):
|
23 |
"""Initialize model and tokenizer"""
|
24 |
self.model, self.tokenizer = self._load_model()
|
25 |
-
|
26 |
-
class SimpleModel(BaseModel):
|
27 |
-
def __init__(self, model, tokenizer):
|
28 |
-
self.model = model
|
29 |
-
self.tokenizer = tokenizer
|
30 |
-
|
31 |
-
def generate(self, prompt, **kwargs):
|
32 |
-
inputs = self.tokenizer(prompt, return_tensors="pt")
|
33 |
-
outputs = self.model.generate(
|
34 |
-
inputs.input_ids,
|
35 |
-
max_length=512,
|
36 |
-
num_return_sequences=1,
|
37 |
-
temperature=0.7,
|
38 |
-
do_sample=True
|
39 |
-
)
|
40 |
-
return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
41 |
-
|
42 |
-
base_model = SimpleModel(self.model, self.tokenizer)
|
43 |
|
44 |
self.translator = Agent(
|
45 |
name="Translator",
|
@@ -108,7 +91,7 @@ class ModelHandler:
|
|
108 |
],
|
109 |
add_references=True,
|
110 |
)
|
111 |
-
|
112 |
@staticmethod
|
113 |
@st.cache_resource
|
114 |
def _load_model():
|
|
|
4 |
from agno.agent import Agent
|
5 |
from agno.tools.arxiv import ArxivTools
|
6 |
from agno.tools.pubmed import PubmedTools
|
7 |
+
from agno.models.huggingface import HuggingFaceModel
|
8 |
|
9 |
MODEL_PATH = "facebook/opt-125m"
|
10 |
|
|
|
22 |
def _initialize_model(self):
|
23 |
"""Initialize model and tokenizer"""
|
24 |
self.model, self.tokenizer = self._load_model()
|
25 |
+
base_model = HuggingFaceModel(model_name=MODEL_PATH)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
self.translator = Agent(
|
28 |
name="Translator",
|
|
|
91 |
],
|
92 |
add_references=True,
|
93 |
)
|
94 |
+
|
95 |
@staticmethod
|
96 |
@st.cache_resource
|
97 |
def _load_model():
|