99from skllm .models ._base .vectorizer import BaseVectorizer
1010from skllm .memory .base import IndexConstructor
1111from typing import Optional
12+ from model_constants import ANTHROPIC_CLAUDE_MODEL , OPENAI_EMBEDDING_MODEL
1213
1314
1415class FewShotClaudeClassifier (BaseFewShotClassifier , ClaudeClassifierMixin , SingleLabelMixin ):
1516 """Few-shot text classifier using Anthropic's Claude API for single-label classification tasks."""
1617
1718 def __init__ (
1819 self ,
19- model : str = "claude-3-haiku-20240307" ,
20+ model : str = ANTHROPIC_CLAUDE_MODEL ,
2021 default_label : str = "Random" ,
2122 prompt_template : Optional [str ] = None ,
2223 key : Optional [str ] = None ,
@@ -28,7 +29,7 @@ def __init__(
2829 Parameters
2930 ----------
3031 model : str, optional
31- model to use, by default "claude-3-haiku-20240307"
32+ model to use
3233 default_label : str, optional
3334 default label for failed prediction; if "Random" -> selects randomly based on class frequencies
3435 prompt_template : Optional[str], optional
@@ -52,7 +53,7 @@ class MultiLabelFewShotClaudeClassifier(
5253
5354 def __init__ (
5455 self ,
55- model : str = "claude-3-haiku-20240307" ,
56+ model : str = ANTHROPIC_CLAUDE_MODEL ,
5657 default_label : str = "Random" ,
5758 max_labels : Optional [int ] = 5 ,
5859 prompt_template : Optional [str ] = None ,
@@ -65,7 +66,7 @@ def __init__(
6566 Parameters
6667 ----------
6768 model : str, optional
68- model to use, by default "claude-3-haiku-20240307"
69+ model to use
6970 default_label : str, optional
7071 default label for failed prediction; if "Random" -> selects randomly based on class frequencies
7172 max_labels : Optional[int], optional
@@ -95,7 +96,7 @@ class DynamicFewShotClaudeClassifier(
9596
9697 def __init__ (
9798 self ,
98- model : str = "claude-3-haiku-20240307" ,
99+ model : str = ANTHROPIC_CLAUDE_MODEL ,
99100 default_label : str = "Random" ,
100101 prompt_template : Optional [str ] = None ,
101102 key : Optional [str ] = None ,
@@ -112,7 +113,7 @@ def __init__(
112113 Parameters
113114 ----------
114115 model : str, optional
115- model to use, by default "claude-3-haiku-20240307"
116+ model to use
116117 default_label : str, optional
117118 default label for failed prediction; if "Random" -> selects randomly based on class frequencies
118119 prompt_template : Optional[str], optional
@@ -129,7 +130,7 @@ def __init__(
129130 metric used for similarity search, by default "euclidean"
130131 """
131132 if vectorizer is None :
132- vectorizer = GPTVectorizer (model = "text-embedding-ada-002" , key = key )
133+ vectorizer = GPTVectorizer (model = OPENAI_EMBEDDING_MODEL , key = key )
133134 super ().__init__ (
134135 model = model ,
135136 default_label = default_label ,
0 commit comments