@@ -740,120 +740,6 @@ public void testCohereGenerateTextModel() throws IOException, InterruptedExcepti
740
740
assertFalse (((String ) responseMap .get ("text" )).isEmpty ());
741
741
}
742
742
743
- public void testCohereClassifyModel () throws IOException , InterruptedException {
744
- // Skip test if key is null
745
- if (COHERE_KEY == null ) {
746
- return ;
747
- }
748
- String entity = "{\n "
749
- + " \" name\" : \" Cohere classify model Connector\" ,\n "
750
- + " \" description\" : \" The connector to public Cohere classify model service\" ,\n "
751
- + " \" version\" : 1,\n "
752
- + " \" client_config\" : {\n "
753
- + " \" max_connection\" : 20,\n "
754
- + " \" connection_timeout\" : 50000,\n "
755
- + " \" read_timeout\" : 50000\n "
756
- + " },\n "
757
- + " \" protocol\" : \" http\" ,\n "
758
- + " \" parameters\" : {\n "
759
- + " \" endpoint\" : \" api.cohere.ai\" ,\n "
760
- + " \" auth\" : \" API_Key\" ,\n "
761
- + " \" content_type\" : \" application/json\" ,\n "
762
- + " \" max_tokens\" : \" 20\" \n "
763
- + " },\n "
764
- + " \" credential\" : {\n "
765
- + " \" cohere_key\" : \" "
766
- + COHERE_KEY
767
- + "\" \n "
768
- + " },\n "
769
- + " \" actions\" : [\n "
770
- + " {\n "
771
- + " \" action_type\" : \" predict\" ,\n "
772
- + " \" method\" : \" POST\" ,\n "
773
- + " \" url\" : \" https://${parameters.endpoint}/v1/classify\" ,\n "
774
- + " \" headers\" : { \n "
775
- + " \" Authorization\" : \" Bearer ${credential.cohere_key}\" \n "
776
- + " },\n "
777
- + " \" request_body\" : \" { \\ \" inputs\\ \" : ${parameters.inputs}, \\ \" examples\\ \" : ${parameters.examples}, \\ \" truncate\\ \" : \\ \" END\\ \" }\" \n "
778
- + " }\n "
779
- + " ]\n "
780
- + "}" ;
781
- Response response = createConnector (entity );
782
- Map responseMap = parseResponseToMap (response );
783
- String connectorId = (String ) responseMap .get ("connector_id" );
784
- response = registerRemoteModel ("cohere classify model" , connectorId );
785
- responseMap = parseResponseToMap (response );
786
- String taskId = (String ) responseMap .get ("task_id" );
787
- waitForTask (taskId , MLTaskState .COMPLETED );
788
- response = getTask (taskId );
789
- responseMap = parseResponseToMap (response );
790
- String modelId = (String ) responseMap .get ("model_id" );
791
- response = deployRemoteModel (modelId );
792
- responseMap = parseResponseToMap (response );
793
- taskId = (String ) responseMap .get ("task_id" );
794
- waitForTask (taskId , MLTaskState .COMPLETED );
795
- String predictInput = "{\n "
796
- + " \" parameters\" : {\n "
797
- + " \" inputs\" : [\n "
798
- + " \" Confirm your email address\" ,\n "
799
- + " \" hey i need u to send some $\" \n "
800
- + " ],\n "
801
- + " \" examples\" : [\n "
802
- + " {\n "
803
- + " \" text\" : \" Dermatologists don't like her!\" ,\n "
804
- + " \" label\" : \" Spam\" \n "
805
- + " },\n "
806
- + " {\n "
807
- + " \" text\" : \" Hello, open to this?\" ,\n "
808
- + " \" label\" : \" Spam\" \n "
809
- + " },\n "
810
- + " {\n "
811
- + " \" text\" : \" I need help please wire me $1000 right now\" ,\n "
812
- + " \" label\" : \" Spam\" \n "
813
- + " },\n "
814
- + " {\n "
815
- + " \" text\" : \" Nice to know you ;)\" ,\n "
816
- + " \" label\" : \" Spam\" \n "
817
- + " },\n "
818
- + " {\n "
819
- + " \" text\" : \" Please help me?\" ,\n "
820
- + " \" label\" : \" Spam\" \n "
821
- + " },\n "
822
- + " {\n "
823
- + " \" text\" : \" Your parcel will be delivered today\" ,\n "
824
- + " \" label\" : \" Not spam\" \n "
825
- + " },\n "
826
- + " {\n "
827
- + " \" text\" : \" Review changes to our Terms and Conditions\" ,\n "
828
- + " \" label\" : \" Not spam\" \n "
829
- + " },\n "
830
- + " {\n "
831
- + " \" text\" : \" Weekly sync notes\" ,\n "
832
- + " \" label\" : \" Not spam\" \n "
833
- + " },\n "
834
- + " {\n "
835
- + " \" text\" : \" Re: Follow up from todays meeting\" ,\n "
836
- + " \" label\" : \" Not spam\" \n "
837
- + " },\n "
838
- + " {\n "
839
- + " \" text\" : \" Pre-read for tomorrow\" ,\n "
840
- + " \" label\" : \" Not spam\" \n "
841
- + " }\n "
842
- + " ]\n "
843
- + " }\n "
844
- + "}" ;
845
-
846
- response = predictRemoteModel (modelId , predictInput );
847
- responseMap = parseResponseToMap (response );
848
- List responseList = (List ) responseMap .get ("inference_results" );
849
- responseMap = (Map ) responseList .get (0 );
850
- responseList = (List ) responseMap .get ("output" );
851
- responseMap = (Map ) responseList .get (0 );
852
- responseMap = (Map ) responseMap .get ("dataAsMap" );
853
- responseList = (List ) responseMap .get ("classifications" );
854
- assertFalse (responseList .isEmpty ());
855
- }
856
-
857
743
public static Response createConnector (String input ) throws IOException {
858
744
try {
859
745
return TestHelper .makeRequest (client (), "POST" , "/_plugins/_ml/connectors/_create" , null , TestHelper .toHttpEntity (input ), null );
0 commit comments