# GSoC-Super-Rapid-Annotator / src/text_processor.py
# Last change: "Update src/text_processor.py" by ManishThota (commit 032d6c3, verified)
# Original size: 3.96 kB
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import json
import warnings
from pydantic import BaseModel
from typing import Dict
import spaces
# Device the model weights are placed on; hard-coded, so a CUDA device is expected.
device = "cuda"

# Silence every warning category process-wide (this affects all modules, not
# just this one).
warnings.filterwarnings("ignore")

# Seed torch's global RNG so sampled generations are more repeatable.
torch.random.manual_seed(0)

# Hugging Face Hub id of the instruction-tuned model used for flag extraction.
model_path = "microsoft/Phi-3-mini-4k-instruct"

# Keyword arguments forwarded to the text-generation pipeline on every call:
# a short completion (the expected answer is a small JSON object), only the
# newly generated text is returned, and sampling uses a low temperature so the
# output stays close to greedy decoding.
generation_args = dict(
    max_new_tokens=50,
    return_full_text=False,
    temperature=0.1,
    do_sample=True,
)
# Load each model's pipeline once and keep it in memory.
def load_model_pipeline(model_path: str):
    """Return a cached text-generation pipeline for ``model_path``.

    The first call for a given path loads the causal-LM weights onto the
    module-level ``device`` and wraps them, together with the matching
    tokenizer, in a ``transformers`` text-generation pipeline. Later calls
    with the same path return the cached pipeline instead of reloading.

    For backward compatibility, the first pipeline ever loaded is also
    exposed as ``load_model_pipeline.pipe``.

    Args:
        model_path: Hugging Face Hub id or local directory of the model.

    Returns:
        The ``transformers`` pipeline built for ``model_path``.
    """
    # Cache keyed by path: the previous single-slot cache returned the first
    # pipeline even when a different ``model_path`` was requested.
    cache = getattr(load_model_pipeline, "_pipes", None)
    if cache is None:
        cache = load_model_pipeline._pipes = {}
    if model_path not in cache:
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map=device,
            torch_dtype="auto",
            # Permits executing modelling code hosted in the model repository.
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        cache[model_path] = pipeline("text-generation", model=model, tokenizer=tokenizer)
    loaded = cache[model_path]
    # Keep the legacy ``.pipe`` attribute pointing at the first pipeline loaded.
    if not hasattr(load_model_pipeline, "pipe"):
        load_model_pipeline.pipe = loaded
    return loaded
# Build the shared pipeline eagerly when the module is imported;
# ``generate_logic`` uses this module-level ``pipe`` for every call.
pipe = load_model_pipeline(model_path=model_path)
# Helper: locate the first JSON object in raw model text.
def _extract_json_object(text: str) -> str:
    """Return the first JSON object embedded in ``text`` as a string.

    Decoding starts at the first ``{``. If a complete, valid JSON value can be
    decoded there, exactly that span is returned, so trailing chatter that
    happens to contain braces is ignored. Otherwise, the widest span from the
    first ``{`` to the last ``}`` is returned, so the caller still surfaces
    the JSON error when parsing. Returns ``""`` when no such span exists.
    """
    start = text.find("{")
    if start == -1:
        return ""
    try:
        _, end = json.JSONDecoder().raw_decode(text, start)
    except json.JSONDecodeError:
        end = text.rfind("}") + 1
        if end <= start:
            # No closing brace after the opening one (e.g. the reply was cut
            # off by the token limit).
            return ""
    return text[start:end]


# Generate logic from LLM output
@spaces.GPU(duration=50)
def generate_logic(llm_output: str) -> str:
    """Ask the LLM to classify ``llm_output`` and return its JSON answer.

    The description is embedded in both the system message and the user
    prompt. The prompt asks for a JSON string with the keys
    ``Screen.interaction_yes``, ``Hands.free``, ``Indoors`` and ``Standing``.
    The call runs inside a GPU slot requested via ``spaces.GPU(duration=50)``.

    Args:
        llm_output: Free-text description of the activity.

    Returns:
        The JSON object text extracted from the model's reply (not yet parsed).

    Raises:
        ValueError: If the reply contains no ``{...}`` span.
    """
    prompt = f"""
Provide the response in json string for the below keys and context based on the description: '{llm_output}'.
Screen.interaction_yes: This field indicates whether there was an interaction of the person with a screen during the activity. A value of 1 means there was screen interaction (Yes), and a value of 0 means there was no screen interaction (No).
Hands.free: This field indicates whether the person's hands were free during the activity. A value of 1 means the person was not holding anything (Yes), indicating free hands. A value of 0 means the person was holding something (No), indicating the hands were not free.
Indoors: This field indicates whether the activity took place indoors. A value of 1 means the activity occurred inside a building or enclosed space (Yes), and a value of 0 means the activity took place outside (No).
Standing: This field indicates whether the person was standing during the activity. A value of 1 means the person was standing (Yes), and a value of 0 means the person was not standing (No).
"""
    messages = [
        {"role": "system", "content": "Please answer questions just based on this information: " + llm_output},
        {"role": "user", "content": prompt},
    ]
    response = pipe(messages, **generation_args)
    generated_text = response[0]['generated_text']
    json_str = _extract_json_object(generated_text)
    # Log the extracted JSON string for debugging.
    print(f"Generated JSON: {json_str}")
    if not json_str.strip():
        raise ValueError("Generated logic is empty or invalid JSON")
    return json_str
# Maps each VideoAnalysis field to the key the LLM is asked to emit for it.
_LLM_KEYS = {
    "screen_interaction_yes": "Screen.interaction_yes",
    "hands_free": "Hands.free",
    "indoors": "Indoors",
    "standing": "Standing",
}


def _lookup_flag(data: dict, dotted_key: str) -> object:
    """Return the value for ``dotted_key`` in ``data``, or ``0`` if absent.

    The literal key (e.g. ``"Hands.free"``) is tried first. If it is missing,
    the dots are treated as a nesting path (``data["Hands"]["free"]``),
    because the model may render dotted names as nested objects.
    """
    if dotted_key in data:
        return data[dotted_key]
    node = data
    for part in dotted_key.split("."):
        if not isinstance(node, dict) or part not in node:
            return 0
        node = node[part]
    return node


# Pydantic model for structured output
class VideoAnalysis(BaseModel):
    """Integer (0/1) flags describing one analysed activity.

    The fields correspond to the keys ``Screen.interaction_yes``,
    ``Hands.free``, ``Indoors`` and ``Standing`` in the model's JSON answer.
    """

    screen_interaction_yes: int  # 1 = person interacted with a screen
    hands_free: int  # 1 = person was not holding anything
    indoors: int  # 1 = activity took place indoors
    standing: int  # 1 = person was standing

    @classmethod
    def from_llm_output(cls, generated_logic: str) -> 'VideoAnalysis':
        """Build a ``VideoAnalysis`` from the JSON text returned by the LLM.

        Missing keys default to ``0``.

        Args:
            generated_logic: JSON object text, e.g. as produced by
                ``generate_logic``.

        Raises:
            ValueError: If ``generated_logic`` is not valid JSON or does not
                decode to a JSON object. Pydantic's validation error for
                non-integer values is also a ``ValueError`` subclass.
        """
        try:
            logic_dict = json.loads(generated_logic)
        except json.JSONDecodeError as e:
            raise ValueError(f"Error decoding JSON: {e}") from e
        # A list/number/string would otherwise fail later with AttributeError.
        if not isinstance(logic_dict, dict):
            raise ValueError(
                f"Expected a JSON object, got {type(logic_dict).__name__}"
            )
        return cls(**{
            field: _lookup_flag(logic_dict, key)
            for field, key in _LLM_KEYS.items()
        })
# Main function to process LLM output
def process_description(description: str) -> Dict:
    """Turn a free-text activity description into a dict of 0/1 flags.

    Args:
        description: Text describing the activity (presumably the output of
            an upstream model — confirm against the caller).

    Returns:
        A dict with keys ``screen_interaction_yes``, ``hands_free``,
        ``indoors`` and ``standing``.

    Raises:
        ValueError: If ``generate_logic`` or ``VideoAnalysis.from_llm_output``
            cannot obtain usable JSON from the model's reply.
    """
    generated_logic = generate_logic(description)
    structured_output = VideoAnalysis.from_llm_output(generated_logic)
    # Pydantic v2 replaced ``.dict()`` with ``.model_dump()`` and deprecated the
    # old name; prefer the new API but keep working on pydantic v1.
    dump = getattr(structured_output, "model_dump", None)
    return dump() if dump is not None else structured_output.dict()